Skip to content

Commit f9492c1

Browse files
zhuyuegongchensu
authored andcommitted
feat: add infinicore.nn.InfiniCoreParameter referencing torch.nn.Parameter and tests.
1 parent 160fd18 commit f9492c1

File tree

4 files changed

+330
-8
lines changed

4 files changed

+330
-8
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .module import InfiniCoreModule as Module
22
from .module_list import InfiniCoreModuleList as ModuleList
3+
from .parameter import InfiniCoreParameter as Parameter

python/infinicore/nn/modules/module.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class InfiniCoreModule:
4646
_version: int = 1
4747

4848
training: bool
49-
_parameters: Dict[str, Optional[torch.nn.Parameter]]
49+
_parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'InfiniCoreParameter']]]
5050
_buffers: Dict[str, Optional[torch.Tensor]]
5151
_non_persistent_buffers_set: Set[str]
5252
_modules: Dict[str, Optional['InfiniCoreModule']]
@@ -84,7 +84,9 @@ def remove_from(*dicts_or_sets) -> None:
8484
d.discard(name)
8585

8686
params = self.__dict__.get("_parameters")
87-
if isinstance(value, torch.nn.Parameter):
87+
# Support both torch.nn.Parameter and InfiniCoreParameter
88+
from .parameter import InfiniCoreParameter
89+
if isinstance(value, (torch.nn.Parameter, InfiniCoreParameter)):
8890
if params is None:
8991
raise AttributeError(
9092
"cannot assign parameters before Module.__init__() call"
@@ -100,7 +102,7 @@ def remove_from(*dicts_or_sets) -> None:
100102
if value is not None:
101103
raise TypeError(
102104
f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
103-
"(torch.nn.Parameter or None expected)"
105+
"(torch.nn.Parameter, InfiniCoreParameter or None expected)"
104106
)
105107
self.register_parameter(name, value)
106108
else:
@@ -239,12 +241,14 @@ def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) ->
239241

240242
if param is None:
241243
self._parameters[name] = None
242-
elif not isinstance(param, torch.nn.Parameter):
243-
raise TypeError(
244-
f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
245-
"(torch.nn.Parameter or None required)"
246-
)
247244
else:
245+
# Support both torch.nn.Parameter and InfiniCoreParameter
246+
from .parameter import InfiniCoreParameter
247+
if not isinstance(param, (torch.nn.Parameter, InfiniCoreParameter)):
248+
raise TypeError(
249+
f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
250+
"(torch.nn.Parameter, InfiniCoreParameter or None required)"
251+
)
248252
self._parameters[name] = param
249253

250254
def get_extra_state(self) -> Any:
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright (c) 2025, InfiniCore
2+
#
3+
# This file contains modified code derived from PyTorch's `torch.nn.Parameter`
4+
# implementation, which is licensed under the BSD 3-Clause License.
5+
#
6+
# The modifications include adaptations for the InfiniCore framework.
7+
#
8+
# Original PyTorch source:
9+
# https://github.com/pytorch/pytorch/blob/main/torch/nn/parameter.py
10+
#
11+
# Referencing PyTorch v2.4.0
12+
#
13+
# The use of this file is governed by the BSD 3-Clause License.
14+
15+
import torch
16+
from typing import Optional
17+
from collections import OrderedDict
18+
19+
20+
class InfiniCoreParameter(torch.Tensor):
21+
r"""A kind of Tensor that is to be considered a module parameter.
22+
23+
Parameters are :class:`~torch.Tensor` subclasses, that have a
24+
very special property when used with :class:`InfiniCoreModule` s - when they're
25+
assigned as Module attributes they are automatically added to the list of
26+
its parameters, and will appear e.g. in :meth:`~InfiniCoreModule.parameters` iterator.
27+
28+
Assigning a Tensor doesn't have such effect. This is because one might
29+
want to cache some temporary state, like last hidden state of the RNN, in
30+
the model. If there was no such class as :class:`InfiniCoreParameter`, these
31+
temporaries would get registered too.
32+
33+
Args:
34+
data (Tensor, optional): parameter tensor. If None, creates an empty tensor.
35+
requires_grad (bool, optional): if the parameter requires gradient. Note that
36+
the torch.no_grad() context does NOT affect the default behavior of
37+
Parameter creation--the Parameter will still have `requires_grad=True` in
38+
:class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more
39+
details. Default: `True`
40+
41+
Example::
42+
43+
>>> import torch
44+
>>> from infinicore.nn.modules import InfiniCoreModule, InfiniCoreParameter
45+
>>>
46+
>>> class MyModule(InfiniCoreModule):
47+
... def __init__(self):
48+
... super().__init__()
49+
... self.weight = InfiniCoreParameter(torch.randn(10, 5))
50+
... self.bias = InfiniCoreParameter(torch.randn(5))
51+
...
52+
>>> module = MyModule()
53+
>>> for param in module.parameters():
54+
... print(param.shape)
55+
torch.Size([10, 5])
56+
torch.Size([5])
57+
"""
58+
59+
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
60+
if data is None:
61+
data = torch.empty(0)
62+
63+
# Handle standard torch.Tensor or InfiniCoreParameter
64+
if type(data) is torch.Tensor or type(data) is InfiniCoreParameter:
65+
# For ease of BC maintenance, keep this path for standard Tensor.
66+
# Eventually (tm), we should change the behavior for standard Tensor to match.
67+
return torch.Tensor._make_subclass(cls, data, requires_grad)
68+
69+
# Path for custom tensors: set a flag on the instance to indicate parameter-ness.
70+
t = data.detach().requires_grad_(requires_grad)
71+
72+
if type(t) is not type(data):
73+
raise RuntimeError(
74+
f"Creating a InfiniCoreParameter from an instance of type {type(data).__name__} "
75+
"requires that detach() returns an instance of the same type, but return "
76+
f"type {type(t).__name__} was found instead. To use the type as a "
77+
"InfiniCoreParameter, please correct the detach() semantics defined by "
78+
"its __torch_dispatch__() implementation."
79+
)
80+
81+
t._is_param = True
82+
return t
83+
84+
# Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
85+
# are still considered that custom tensor type and these methods will not be called for them.
86+
87+
def __deepcopy__(self, memo):
88+
if id(self) in memo:
89+
return memo[id(self)]
90+
else:
91+
result = type(self)(
92+
self.data.clone(memory_format=torch.preserve_format), self.requires_grad
93+
)
94+
memo[id(self)] = result
95+
return result
96+
97+
def __repr__(self):
98+
return "InfiniCoreParameter containing:\n" + super().__repr__()
99+
100+
def __reduce_ex__(self, proto):
101+
# Simplified version for serialization
102+
# In a full implementation, you might want to handle hooks and state
103+
state = getattr(self, '_state', None)
104+
hooks = OrderedDict()
105+
106+
if not state:
107+
return (
108+
_rebuild_parameter,
109+
(self.data, self.requires_grad, hooks),
110+
)
111+
return (
112+
_rebuild_parameter_with_state,
113+
(self.data, self.requires_grad, hooks, state),
114+
)
115+
116+
# Note: __torch_function__ is handled by the Tensor base class
117+
# We don't need to override it for standard Parameter behavior
118+
119+
120+
def _rebuild_parameter(data, requires_grad, hooks):
121+
"""Rebuild a parameter from serialized data."""
122+
param = InfiniCoreParameter(data, requires_grad)
123+
# Apply hooks if any (simplified - full implementation would restore hooks)
124+
return param
125+
126+
127+
def _rebuild_parameter_with_state(data, requires_grad, hooks, state):
128+
"""Rebuild a parameter with extra state from serialized data."""
129+
param = InfiniCoreParameter(data, requires_grad)
130+
param._state = state
131+
# Apply hooks if any (simplified - full implementation would restore hooks)
132+
return param
133+
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import safetensors.torch
2+
import torch
3+
import torch.nn as nn
4+
import safetensors
5+
6+
# ============================================================
7+
# 0. infinicore 包导入,配置测试用 safetensors 临时存储路径
8+
# ============================================================
9+
import sys
10+
import os
11+
12+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))
13+
14+
save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
15+
os.makedirs(save_dir, exist_ok=True)
16+
save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors")
17+
18+
# ============================================================
19+
# 1. 使用 PyTorch 定义并保存模型(使用 torch.nn.Parameter)
20+
# ============================================================
21+
22+
class TorchParameterNet(nn.Module):
23+
def __init__(self, in_features=10, out_features=5):
24+
super().__init__()
25+
self.weight = nn.Parameter(torch.randn(out_features, in_features))
26+
self.bias = nn.Parameter(torch.randn(out_features))
27+
self.scale = nn.Parameter(torch.ones(1) * 0.5)
28+
self.register_buffer("offset", torch.tensor(0.1))
29+
30+
def forward(self, x):
31+
return (x @ self.weight.t() + self.bias) * self.scale + self.offset
32+
33+
34+
# ===== 保存 Torch 模型 =====
35+
torch_model = TorchParameterNet()
36+
torch_state_dict = torch_model.state_dict()
37+
safetensors.torch.save_file(torch_state_dict, save_path)
38+
print("✓ PyTorch 模型已保存")
39+
40+
# ============================================================
41+
# 2. 使用 torch 方式加载并推理
42+
# ============================================================
43+
44+
torch_model_infer = TorchParameterNet()
45+
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
46+
torch_model_infer.eval()
47+
48+
input = torch.randn(2, 10)
49+
torch_model_out = torch_model_infer(input)
50+
print("✓ Torch 输出:", torch_model_out.detach().numpy().mean())
51+
52+
# ============================================================
53+
# 3. 使用 Parameter 加载并推理
54+
# ============================================================
55+
56+
from nn.modules import Module, Parameter
57+
58+
class InfiniCoreParameterNet(Module):
59+
def __init__(self, in_features=10, out_features=5):
60+
super().__init__()
61+
# 使用 Parameter 替代 torch.nn.Parameter
62+
self.weight = Parameter(torch.randn(out_features, in_features))
63+
self.bias = Parameter(torch.randn(out_features))
64+
self.scale = Parameter(torch.ones(1) * 0.5)
65+
self.register_buffer("offset", torch.tensor(0.1))
66+
67+
def forward(self, x):
68+
return (x @ self.weight.t() + self.bias) * self.scale + self.offset
69+
70+
# ===== 使用 InfiniCoreParameterNet 读取 safetensors 并推理 =====
71+
infinicore_model_infer = InfiniCoreParameterNet()
72+
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
73+
infinicore_model_infer.eval()
74+
75+
infinicore_model_out = infinicore_model_infer.forward(input)
76+
print("✓ InfiniCore 输出:", infinicore_model_out.detach().numpy().mean())
77+
78+
# ============================================================
79+
# 4. 对比结果
80+
# ============================================================
81+
82+
diff = (infinicore_model_out - torch_model_out).abs().max().item()
83+
print(f"✓ Parameter 与 Torch 最大误差: {diff:.8f}")
84+
if diff < 1e-9:
85+
print("✓ Parameter 与 Torch 精度一致.")
86+
else:
87+
print("✗ Parameter 与 Torch 精度存在差异.")
88+
89+
# ============================================================
90+
# 5. 测试 Parameter 的基本功能
91+
# ============================================================
92+
93+
print("\n=== 测试 Parameter 基本功能 ===")
94+
95+
# 测试 1: 创建 Parameter
96+
param1 = Parameter(torch.randn(5, 10))
97+
print(f"✓ 创建 Parameter,形状: {param1.shape}")
98+
# 检查是否是 Parameter 类型(可能是 InfiniCoreParameter 的别名)
99+
from nn.modules.parameter import InfiniCoreParameter
100+
assert isinstance(param1, (Parameter, InfiniCoreParameter)), "应该是 Parameter 类型"
101+
assert isinstance(param1, torch.Tensor), "应该是 torch.Tensor 的子类"
102+
103+
# 测试 2: requires_grad
104+
param2 = Parameter(torch.randn(3, 4), requires_grad=False)
105+
print(f"✓ 创建 requires_grad=False 的 Parameter: {param2.requires_grad}")
106+
assert not param2.requires_grad, "requires_grad 应该为 False"
107+
108+
param3 = Parameter(torch.randn(3, 4), requires_grad=True)
109+
print(f"✓ 创建 requires_grad=True 的 Parameter: {param3.requires_grad}")
110+
assert param3.requires_grad, "requires_grad 应该为 True"
111+
112+
# 测试 3: 自动注册到 Module
113+
class TestModule(Module):
114+
def __init__(self):
115+
super().__init__()
116+
self.weight = Parameter(torch.randn(5, 10))
117+
self.bias = Parameter(torch.randn(5))
118+
119+
test_module = TestModule()
120+
param_count = sum(1 for _ in test_module.parameters())
121+
print(f"✓ 自动注册到 Module,参数数量: {param_count}")
122+
assert param_count == 2, f"应该有 2 个参数,实际为 {param_count}"
123+
124+
# 测试 4: 参数访问
125+
assert test_module.weight is not None, "weight 应该可以访问"
126+
assert test_module.bias is not None, "bias 应该可以访问"
127+
print("✓ 参数可以通过属性访问")
128+
129+
# 测试 5: state_dict
130+
state_dict = test_module.state_dict()
131+
print(f"✓ state_dict 键数量: {len(state_dict)}")
132+
assert 'weight' in state_dict, "state_dict 应该包含 weight"
133+
assert 'bias' in state_dict, "state_dict 应该包含 bias"
134+
print(f"✓ state_dict 键: {list(state_dict.keys())}")
135+
136+
# 测试 6: __repr__
137+
repr_str = repr(param1)
138+
print(f"✓ __repr__ 方法: 输出包含类名")
139+
assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr 应该包含类名"
140+
print(repr_str[:100] + "...")
141+
142+
# 测试 7: 与 torch.nn.Parameter 兼容性
143+
class MixedModule(Module):
144+
def __init__(self):
145+
super().__init__()
146+
self.torch_param = nn.Parameter(torch.randn(3, 4))
147+
self.infinicore_param = Parameter(torch.randn(3, 4))
148+
149+
mixed_module = MixedModule()
150+
mixed_param_count = sum(1 for _ in mixed_module.parameters())
151+
print(f"✓ 混合使用 torch.nn.Parameter 和 Parameter,参数数量: {mixed_param_count}")
152+
assert mixed_param_count == 2, f"应该有 2 个参数,实际为 {mixed_param_count}"
153+
154+
# 测试 8: 前向传播
155+
class TestModuleWithForward(Module):
156+
def __init__(self):
157+
super().__init__()
158+
self.weight = Parameter(torch.randn(5, 10))
159+
self.bias = Parameter(torch.randn(5))
160+
161+
def forward(self, x):
162+
return x @ self.weight.t() + self.bias
163+
164+
test_module_forward = TestModuleWithForward()
165+
test_input = torch.randn(2, 10)
166+
with torch.no_grad():
167+
output = test_module_forward.forward(test_input)
168+
print(f"✓ 前向传播成功,输出形状: {output.shape}")
169+
assert output.shape == (2, 5), f"输出形状应该是 (2, 5),实际为 {output.shape}"
170+
171+
# 测试 9: 从 None 创建
172+
param_empty = Parameter(None)
173+
print(f"✓ 从 None 创建 Parameter,形状: {param_empty.shape}")
174+
assert param_empty.shape == torch.Size([0]), "从 None 创建应该是空张量"
175+
176+
# 测试 10: 深拷贝
177+
import copy
178+
param_copy = copy.deepcopy(param1)
179+
print(f"✓ 深拷贝 Parameter,形状: {param_copy.shape}")
180+
assert param_copy.shape == param1.shape, "深拷贝后形状应该相同"
181+
assert not torch.equal(param_copy, param1) or id(param_copy) != id(param1), "深拷贝应该是新对象"
182+
183+
print("\n=== 所有测试通过! ===")
184+

0 commit comments

Comments
 (0)