Skip to content

Commit dfb1e8d

Browse files
author
pengcheng888
committed
issue/890 - 为python端的nn.module添加to函数
1 parent 3b5afff commit dfb1e8d

File tree

2 files changed

+124
-12
lines changed

2 files changed

+124
-12
lines changed

python/infinicore/nn/modules/module.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import infinicore
3434

35+
from ...device import device as InfiniCoreDevice
3536
from ...tensor import Tensor
3637
from ..parameter import InfiniCoreParameter as Parameter
3738

@@ -481,15 +482,14 @@ def _load_from_state_dict(
481482
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
482483
)
483484

484-
if (
485-
(param.shape == input_param.shape)
486-
and (param.dtype == input_param.dtype)
487-
and (param.device == input_param.device)
485+
if (param.shape == input_param.shape) and (
486+
param.dtype == input_param.dtype
488487
):
489488
param.copy_(input_param)
490489
else:
491-
print(f"param '{name}' don't match input_param '{key}'")
492-
setattr(self, name, input_param)
490+
raise KeyError(
491+
f"param '{name}' don't match input_param '{key}' with shape or dtype"
492+
)
493493

494494
elif strict:
495495
missing_keys.append(key)
@@ -848,10 +848,29 @@ def eval(self: T) -> T:
848848
Returns:
849849
Module: self
850850
"""
851-
pass
851+
raise KeyError("not support")
852852

853853
def _apply(self, fn, recurse=True):
854-
raise KeyError("not support")
854+
if recurse:
855+
for module in self.children():
856+
module._apply(fn)
855857

856-
def to(self, *args, **kwargs):
857-
raise KeyError("not support")
858+
for key, param in self._parameters.items():
859+
if param is not None:
860+
setattr(self, key, fn(param))
861+
862+
for key, buf in self._buffers.items():
863+
if buf is not None:
864+
setattr(self, key, fn(buf))
865+
866+
return self
867+
868+
def to(self, device: str | InfiniCoreDevice):
869+
if device is None:
870+
raise ValueError("device cannot be None")
871+
device = InfiniCoreDevice(device)
872+
873+
def convert(t):
874+
return t.to(device)
875+
876+
return self._apply(convert)

test/infinicore/nn/module.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self):
4444
def forward(self):
4545
return infinicore.add(self.a, self.b)
4646

47+
4748
infinicore_model_infer = InfiniCoreNet()
4849
# ============================================================
4950
# 2. 加载权重
@@ -75,6 +76,98 @@ def forward(self):
7576

7677

7778
# ============================================================
78-
# 5. to测试,buffer测试
79+
# 5. to测试 - 测试模型在不同设备间的转换
7980
# ============================================================
80-
# 等待添加
81+
print("\n" + "=" * 60)
82+
print("5. to测试 - 设备转换测试")
83+
print("=" * 60)
84+
85+
# 5.1 打印初始状态
86+
print("\n5.1 初始状态:")
87+
print("-" * 40)
88+
print("Parameters:")
89+
for name, param in infinicore_model_infer.named_parameters():
90+
print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
91+
print("Buffers:")
92+
buffers_exist = False
93+
for name, buf in infinicore_model_infer.named_buffers():
94+
buffers_exist = True
95+
print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
96+
if not buffers_exist:
97+
print(" (无buffers)")
98+
99+
# 5.2 测试转换到CUDA设备(使用device对象)
100+
print("\n5.2 转换到CUDA设备 (使用 infinicore.device('cuda', 0)):")
101+
print("-" * 40)
102+
target_device_cuda = infinicore.device("cuda", 0)
103+
infinicore_model_infer.to(target_device_cuda)
104+
105+
print("转换后的Parameters:")
106+
for name, param in infinicore_model_infer.named_parameters():
107+
print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
108+
# 验证设备是否正确转换
109+
assert param.device == target_device_cuda, (
110+
f"参数 {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {param.device}"
111+
)
112+
if buffers_exist:
113+
print("转换后的Buffers:")
114+
for name, buf in infinicore_model_infer.named_buffers():
115+
print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
116+
assert buf.device == target_device_cuda, (
117+
f"Buffer {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {buf.device}"
118+
)
119+
print("✓ CUDA设备转换验证通过")
120+
121+
# 5.3 测试转换到CPU设备(使用device对象)
122+
print("\n5.3 转换到CPU设备 (使用 infinicore.device('cpu', 0)):")
123+
print("-" * 40)
124+
target_device_cpu = infinicore.device("cpu", 0)
125+
infinicore_model_infer.to(target_device_cpu)
126+
127+
print("转换后的Parameters:")
128+
for name, param in infinicore_model_infer.named_parameters():
129+
print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
130+
# 验证设备是否正确转换
131+
assert param.device == target_device_cpu, (
132+
f"参数 {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {param.device}"
133+
)
134+
if buffers_exist:
135+
print("转换后的Buffers:")
136+
for name, buf in infinicore_model_infer.named_buffers():
137+
print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
138+
assert buf.device == target_device_cpu, (
139+
f"Buffer {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {buf.device}"
140+
)
141+
print("✓ CPU设备转换验证通过")
142+
143+
# 5.4 测试使用字符串参数转换到CUDA设备
144+
print("\n5.4 转换到CUDA设备 (使用字符串 'cuda'):")
145+
print("-" * 40)
146+
infinicore_model_infer.to("cuda")
147+
148+
print("转换后的Parameters:")
149+
for name, param in infinicore_model_infer.named_parameters():
150+
print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
151+
# 验证设备是否正确转换(字符串'cuda'会被转换为cuda设备)
152+
assert param.device.type == "cuda", (
153+
f"参数 {name} 的设备转换失败: 期望 cuda, 实际 {param.device.type}"
154+
)
155+
if buffers_exist:
156+
print("转换后的Buffers:")
157+
for name, buf in infinicore_model_infer.named_buffers():
158+
print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
159+
assert buf.device.type == "cuda", (
160+
f"Buffer {name} 的设备转换失败: 期望 cuda, 实际 {buf.device.type}"
161+
)
162+
print("✓ 字符串参数设备转换验证通过")
163+
164+
# 5.5 验证to方法返回self(链式调用支持)
165+
print("\n5.5 测试to方法的返回值(链式调用):")
166+
print("-" * 40)
167+
result = infinicore_model_infer.to(infinicore.device("cpu", 0))
168+
assert result is infinicore_model_infer, "to方法应该返回self以支持链式调用"
169+
print("✓ to方法返回值验证通过")
170+
171+
print("\n" + "=" * 60)
172+
print("所有to测试通过!")
173+
print("=" * 60 + "\n")

0 commit comments

Comments
 (0)