Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions python/infinicore/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import infinicore

from ...device import device as InfiniCoreDevice
from ...tensor import Tensor
from ..parameter import InfiniCoreParameter as Parameter

Expand Down Expand Up @@ -481,15 +482,14 @@ def _load_from_state_dict(
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
)

if (
(param.shape == input_param.shape)
and (param.dtype == input_param.dtype)
and (param.device == input_param.device)
if (param.shape == input_param.shape) and (
param.dtype == input_param.dtype
):
param.copy_(input_param)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because infinicore's tensor-to-tensor copy already supports copying directly from CPU to GPU, the device-equality check was dropped.

The weight-loading rule is now: data is copied only when the model weight and the checkpoint entry agree in both shape and dtype; otherwise an error is raised.

else:
print(f"param '{name}' don't match input_param '{key}'")
setattr(self, name, input_param)
raise KeyError(
f"param '{name}' don't match input_param '{key}' with shape or dtype"
)

elif strict:
missing_keys.append(key)
Expand Down Expand Up @@ -848,10 +848,29 @@ def eval(self: T) -> T:
Returns:
Module: self
"""
pass
raise KeyError("not support")

def _apply(self, fn, recurse=True):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to函数的部分参考了torch的写法

raise KeyError("not support")
if recurse:
for module in self.children():
module._apply(fn)

def to(self, *args, **kwargs):
raise KeyError("not support")
for key, param in self._parameters.items():
if param is not None:
setattr(self, key, fn(param))
Copy link
Collaborator Author

@pengcheng888 pengcheng888 Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用=符号,赋值不成功,不知为何。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最后用了setattr(self, key, fn(param))


for key, buf in self._buffers.items():
if buf is not None:
setattr(self, key, fn(buf))

return self

def to(self, device: str | InfiniCoreDevice):
    """Move every parameter and buffer of this module to *device*.

    Args:
        device: Target device, either a device string (e.g. ``"cuda"``)
            or an ``InfiniCoreDevice`` instance.

    Returns:
        Module: self, to allow call chaining.

    Raises:
        ValueError: If *device* is ``None``.
    """
    if device is None:
        raise ValueError("device cannot be None")

    # Normalize the argument, then rebind each tensor via _apply.
    target = InfiniCoreDevice(device)
    return self._apply(lambda t: t.to(target))
97 changes: 95 additions & 2 deletions test/infinicore/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self):
def forward(self):
    """Forward pass: return the result of infinicore.add on parameters ``a`` and ``b``."""
    combined = infinicore.add(self.a, self.b)
    return combined


# Instantiate the model under test.
infinicore_model_infer = InfiniCoreNet()
# ============================================================
# 2. Load weights
Expand Down Expand Up @@ -75,6 +76,98 @@ def forward(self):


# ============================================================
# 5. .to() test - verify moving the model between devices
# ============================================================
print("\n" + "=" * 60)
print("5. to测试 - 设备转换测试")
print("=" * 60)

# 5.1 Print the initial placement of every parameter and buffer.
print("\n5.1 初始状态:")
print("-" * 40)
print("Parameters:")
for name, param in infinicore_model_infer.named_parameters():
    print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
print("Buffers:")
buffers_exist = False
for name, buf in infinicore_model_infer.named_buffers():
    buffers_exist = True
    print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
if not buffers_exist:
    print(" (无buffers)")

# 5.2 Move to a CUDA device, passing a device object.
# NOTE(review): Module.to() re-wraps its argument in InfiniCoreDevice —
# confirm the constructor accepts an already-constructed device instance.
print("\n5.2 转换到CUDA设备 (使用 infinicore.device('cuda', 0)):")
print("-" * 40)
target_device_cuda = infinicore.device("cuda", 0)
infinicore_model_infer.to(target_device_cuda)

print("转换后的Parameters:")
for name, param in infinicore_model_infer.named_parameters():
    print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
    # Check that the parameter actually landed on the requested device.
    assert param.device == target_device_cuda, (
        f"参数 {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {param.device}"
    )
if buffers_exist:
    print("转换后的Buffers:")
    for name, buf in infinicore_model_infer.named_buffers():
        print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
        assert buf.device == target_device_cuda, (
            f"Buffer {name} 的设备转换失败: 期望 {target_device_cuda}, 实际 {buf.device}"
        )
print("✓ CUDA设备转换验证通过")

# 5.3 Move back to the CPU, passing a device object.
print("\n5.3 转换到CPU设备 (使用 infinicore.device('cpu', 0)):")
print("-" * 40)
target_device_cpu = infinicore.device("cpu", 0)
infinicore_model_infer.to(target_device_cpu)

print("转换后的Parameters:")
for name, param in infinicore_model_infer.named_parameters():
    print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
    # Check that the parameter actually landed on the requested device.
    assert param.device == target_device_cpu, (
        f"参数 {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {param.device}"
    )
if buffers_exist:
    print("转换后的Buffers:")
    for name, buf in infinicore_model_infer.named_buffers():
        print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
        assert buf.device == target_device_cpu, (
            f"Buffer {name} 的设备转换失败: 期望 {target_device_cpu}, 实际 {buf.device}"
        )
print("✓ CPU设备转换验证通过")

# 5.4 Move to CUDA again, this time passing a plain string.
print("\n5.4 转换到CUDA设备 (使用字符串 'cuda'):")
print("-" * 40)
infinicore_model_infer.to("cuda")

print("转换后的Parameters:")
for name, param in infinicore_model_infer.named_parameters():
    print(f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
    # The string 'cuda' should be converted into a cuda device internally.
    assert param.device.type == "cuda", (
        f"参数 {name} 的设备转换失败: 期望 cuda, 实际 {param.device.type}"
    )
if buffers_exist:
    print("转换后的Buffers:")
    for name, buf in infinicore_model_infer.named_buffers():
        print(f" {name}: shape={buf.shape}, dtype={buf.dtype}, device={buf.device}")
        assert buf.device.type == "cuda", (
            f"Buffer {name} 的设备转换失败: 期望 cuda, 实际 {buf.device.type}"
        )
print("✓ 字符串参数设备转换验证通过")

# 5.5 to() must return self so calls can be chained.
print("\n5.5 测试to方法的返回值(链式调用):")
print("-" * 40)
result = infinicore_model_infer.to(infinicore.device("cpu", 0))
assert result is infinicore_model_infer, "to方法应该返回self以支持链式调用"
print("✓ to方法返回值验证通过")

print("\n" + "=" * 60)
print("所有to测试通过!")
print("=" * 60 + "\n")