diff --git a/python/infinicore/nn/modules/module.py b/python/infinicore/nn/modules/module.py index d21223903..013f8331d 100644 --- a/python/infinicore/nn/modules/module.py +++ b/python/infinicore/nn/modules/module.py @@ -32,6 +32,7 @@ import infinicore +from ...device import device as InfiniCoreDevice from ...tensor import Tensor from ..parameter import InfiniCoreParameter as Parameter @@ -481,15 +482,14 @@ def _load_from_state_dict( f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}" ) - if ( - (param.shape == input_param.shape) - and (param.dtype == input_param.dtype) - and (param.device == input_param.device) + if (param.shape == input_param.shape) and ( + param.dtype == input_param.dtype ): param.copy_(input_param) else: - print(f"param '{name}' don't match input_param '{key}'") - setattr(self, name, input_param) + raise KeyError( + f"param '{name}' does not match input_param '{key}' in shape or dtype" + ) elif strict: missing_keys.append(key) @@ -848,10 +848,29 @@ def eval(self: T) -> T: Returns: Module: self """ - pass + raise NotImplementedError("eval() is not supported") def _apply(self, fn, recurse=True): - raise KeyError("not support") + if recurse: + for module in self.children(): + module._apply(fn) - def to(self, *args, **kwargs): - raise KeyError("not support") + for key, param in self._parameters.items(): + if param is not None: + setattr(self, key, fn(param)) + + for key, buf in self._buffers.items(): + if buf is not None: + setattr(self, key, fn(buf)) + + return self + + def to(self, device: str | InfiniCoreDevice): + if device is None: + raise ValueError("device cannot be None") + device = InfiniCoreDevice(device) + + def convert(t): + return t.to(device) + + return self._apply(convert) diff --git a/test/infinicore/nn/module.py b/test/infinicore/nn/module.py index 69e341fa2..4abaeba3b 100644 --- a/test/infinicore/nn/module.py +++ b/test/infinicore/nn/module.py @@ -44,6 +44,7 @@ def __init__(self): def 
forward(self): return infinicore.add(self.a, self.b) + infinicore_model_infer = InfiniCoreNet() # ============================================================ # 2. 加载权重 @@ -75,6 +76,98 @@ def forward(self): # ============================================================ -# 5. to测试,buffer测试 +# 5. to测试 - 测试模型在不同设备间的转换 # ============================================================ -# 等待添加 +print("\n" + "=" * 60) +print("5. to测试 - 设备转换测试") +print("=" * 60) + +# 5.1 打印初始状态 +print("\n5.1 初始状态:") +print("-" * 40) +print("Parameters:") +for name, param in infinicore_model_infer.named_parameters(): + print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}") +print("Buffers:") +buffers_exist = False +for name, buf in infinicore_model_infer.named_buffers(): + buffers_exist = True + print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}") +if not buffers_exist: + print(" (无buffers)") + +# 5.2 测试转换到CUDA设备(使用device对象) +print("\n5.2 转换到CUDA设备 (使用 infinicore.device('cuda', 0)):") +print("-" * 40) +target_device_cuda = infinicore.device("cuda", 0) +infinicore_model_infer.to(target_device_cuda) + +print("转换后的Parameters:") +for name, param in infinicore_model_infer.named_parameters(): + print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}") + # 验证设备是否正确转换 + assert param.device == target_device_cuda, ( + f"参数 {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {param.device}" + ) +if buffers_exist: + print("转换后的Buffers:") + for name, buf in infinicore_model_infer.named_buffers(): + print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}") + assert buf.device == target_device_cuda, ( + f"Buffer {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {buf.device}" + ) +print("✓ CUDA设备转换验证通过") + +# 5.3 测试转换到CPU设备(使用device对象) +print("\n5.3 转换到CPU设备 (使用 infinicore.device('cpu', 0)):") +print("-" * 40) +target_device_cpu = infinicore.device("cpu", 0) +infinicore_model_infer.to(target_device_cpu) + 
+print("转换后的Parameters:") +for name, param in infinicore_model_infer.named_parameters(): + print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}") + # 验证设备是否正确转换 + assert param.device == target_device_cpu, ( + f"参数 {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {param.device}" + ) +if buffers_exist: + print("转换后的Buffers:") + for name, buf in infinicore_model_infer.named_buffers(): + print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}") + assert buf.device == target_device_cpu, ( + f"Buffer {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {buf.device}" + ) +print("✓ CPU设备转换验证通过") + +# 5.4 测试使用字符串参数转换到CUDA设备 +print("\n5.4 转换到CUDA设备 (使用字符串 'cuda'):") +print("-" * 40) +infinicore_model_infer.to("cuda") + +print("转换后的Parameters:") +for name, param in infinicore_model_infer.named_parameters(): + print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}") + # 验证设备是否正确转换(字符串'cuda'会被转换为cuda设备) + assert param.device.type == "cuda", ( + f"参数 {name} 的设备转换失败: 期望 cuda, 实际 {param.device.type}" + ) +if buffers_exist: + print("转换后的Buffers:") + for name, buf in infinicore_model_infer.named_buffers(): + print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}") + assert buf.device.type == "cuda", ( + f"Buffer {name} 的设备转换失败: 期望 cuda, 实际 {buf.device.type}" + ) +print("✓ 字符串参数设备转换验证通过") + +# 5.5 验证to方法返回self(链式调用支持) +print("\n5.5 测试to方法的返回值(链式调用):") +print("-" * 40) +result = infinicore_model_infer.to(infinicore.device("cpu", 0)) +assert result is infinicore_model_infer, "to方法应该返回self以支持链式调用" +print("✓ to方法返回值验证通过") + +print("\n" + "=" * 60) +print("所有to测试通过!") +print("=" * 60 + "\n")