import torch
import torch.nn as nn
import safetensors.torch

# ============================================================
# 0. Make the infinicore package importable and set up a temporary
#    safetensors path for this test
# ============================================================
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../python/infinicore')))

save_dir = os.path.join(os.path.dirname(__file__), '../../tmp')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "infinicore_parameter_test.safetensors")

# ============================================================
# 1. Define and save a model with PyTorch (torch.nn.Parameter)
# ============================================================

class TorchParameterNet(nn.Module):
    def __init__(self, in_features=10, out_features=5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))
        self.scale = nn.Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        return (x @ self.weight.t() + self.bias) * self.scale + self.offset


# ===== Save the Torch model =====
torch_model = TorchParameterNet()
torch_state_dict = torch_model.state_dict()
safetensors.torch.save_file(torch_state_dict, save_path)
print("✓ PyTorch model saved")

# ============================================================
# 2. Load with torch and run inference
# ============================================================

torch_model_infer = TorchParameterNet()
torch_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
torch_model_infer.eval()

x = torch.randn(2, 10)
torch_model_out = torch_model_infer(x)
print("✓ Torch output mean:", torch_model_out.detach().numpy().mean())

# ============================================================
# 3. Load and run inference with the infinicore Parameter
# ============================================================

from nn.modules import Module, Parameter

class InfiniCoreParameterNet(Module):
    def __init__(self, in_features=10, out_features=5):
        super().__init__()
        # Use the infinicore Parameter in place of torch.nn.Parameter
        self.weight = Parameter(torch.randn(out_features, in_features))
        self.bias = Parameter(torch.randn(out_features))
        self.scale = Parameter(torch.ones(1) * 0.5)
        self.register_buffer("offset", torch.tensor(0.1))

    def forward(self, x):
        return (x @ self.weight.t() + self.bias) * self.scale + self.offset

# ===== Load the safetensors file with InfiniCoreParameterNet and run inference =====
infinicore_model_infer = InfiniCoreParameterNet()
infinicore_model_infer.load_state_dict(safetensors.torch.load_file(save_path))
infinicore_model_infer.eval()

infinicore_model_out = infinicore_model_infer.forward(x)
print("✓ InfiniCore output mean:", infinicore_model_out.detach().numpy().mean())

# ============================================================
# 4. Compare the results
# ============================================================

diff = (infinicore_model_out - torch_model_out).abs().max().item()
print(f"✓ Max absolute difference between Parameter and Torch: {diff:.8f}")
if diff < 1e-9:
    print("✓ Parameter matches the Torch reference.")
else:
    print("✗ Parameter and Torch outputs differ.")

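# Optional illustrative check (not part of the original comparison): torch.allclose
# gives a tolerance-based comparison in case the two paths are not bit-identical.
if torch.allclose(infinicore_model_out, torch_model_out, rtol=1e-6, atol=1e-8):
    print("✓ torch.allclose also reports the outputs as matching")
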
# ============================================================
# 5. Test basic Parameter functionality
# ============================================================

print("\n=== Testing basic Parameter functionality ===")

# Test 1: create a Parameter
param1 = Parameter(torch.randn(5, 10))
print(f"✓ Created Parameter with shape: {param1.shape}")
# Check the type (Parameter may be an alias of InfiniCoreParameter)
from nn.modules.parameter import InfiniCoreParameter
assert isinstance(param1, (Parameter, InfiniCoreParameter)), "should be a Parameter"
assert isinstance(param1, torch.Tensor), "should be a subclass of torch.Tensor"

# Test 2: requires_grad
param2 = Parameter(torch.randn(3, 4), requires_grad=False)
print(f"✓ Created Parameter with requires_grad=False: {param2.requires_grad}")
assert not param2.requires_grad, "requires_grad should be False"

param3 = Parameter(torch.randn(3, 4), requires_grad=True)
print(f"✓ Created Parameter with requires_grad=True: {param3.requires_grad}")
assert param3.requires_grad, "requires_grad should be True"
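
# Illustrative extra check (an assumption, not part of the original test: it presumes
# Parameter participates in autograd like torch.nn.Parameter, which is plausible since
# it is asserted above to be a torch.Tensor subclass): backward() should populate .grad.
(param3 * 2.0).sum().backward()
assert param3.grad is not None, "backward() should populate .grad when requires_grad=True"
print("✓ Gradients flow to a Parameter with requires_grad=True")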

# Test 3: automatic registration on a Module
class TestModule(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter(torch.randn(5, 10))
        self.bias = Parameter(torch.randn(5))

test_module = TestModule()
param_count = sum(1 for _ in test_module.parameters())
print(f"✓ Automatically registered on the Module, parameter count: {param_count}")
assert param_count == 2, f"expected 2 parameters, got {param_count}"
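
# Illustrative follow-up (assumes parameters() yields the registered tensors themselves):
# every yielded object should be a Parameter instance.
assert all(isinstance(p, (Parameter, InfiniCoreParameter)) for p in test_module.parameters()), \
    "parameters() should yield Parameter instances"
print("✓ parameters() yields Parameter instances")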

# Test 4: attribute access
assert test_module.weight is not None, "weight should be accessible"
assert test_module.bias is not None, "bias should be accessible"
print("✓ Parameters are accessible as attributes")

# Test 5: state_dict
state_dict = test_module.state_dict()
print(f"✓ state_dict key count: {len(state_dict)}")
assert 'weight' in state_dict, "state_dict should contain weight"
assert 'bias' in state_dict, "state_dict should contain bias"
print(f"✓ state_dict keys: {list(state_dict.keys())}")
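
# Illustrative round-trip check (load_state_dict is already exercised above, so this only
# assumes it also accepts the module's own state_dict): a fresh module loaded from the
# state dict should hold identical values.
test_module_reloaded = TestModule()
test_module_reloaded.load_state_dict(state_dict)
assert torch.equal(test_module_reloaded.weight, test_module.weight), "weight should round-trip"
assert torch.equal(test_module_reloaded.bias, test_module.bias), "bias should round-trip"
print("✓ state_dict round-trips through load_state_dict")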

# Test 6: __repr__
repr_str = repr(param1)
print("✓ __repr__ output contains the class name")
assert "Parameter" in repr_str or "InfiniCoreParameter" in repr_str, "repr should contain the class name"
print(repr_str[:100] + "...")

# Test 7: compatibility with torch.nn.Parameter
class MixedModule(Module):
    def __init__(self):
        super().__init__()
        self.torch_param = nn.Parameter(torch.randn(3, 4))
        self.infinicore_param = Parameter(torch.randn(3, 4))

mixed_module = MixedModule()
mixed_param_count = sum(1 for _ in mixed_module.parameters())
print(f"✓ Mixed torch.nn.Parameter and Parameter, parameter count: {mixed_param_count}")
assert mixed_param_count == 2, f"expected 2 parameters, got {mixed_param_count}"
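
# Illustrative follow-up (assumes state_dict covers both parameter kinds, consistent with
# test 5 above): both attributes should appear under their registered names.
mixed_state = mixed_module.state_dict()
assert "torch_param" in mixed_state and "infinicore_param" in mixed_state, \
    "state_dict should contain both parameter kinds"
print(f"✓ Mixed state_dict keys: {list(mixed_state.keys())}")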

# Test 8: forward pass
class TestModuleWithForward(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter(torch.randn(5, 10))
        self.bias = Parameter(torch.randn(5))

    def forward(self, x):
        return x @ self.weight.t() + self.bias

test_module_forward = TestModuleWithForward()
test_input = torch.randn(2, 10)
with torch.no_grad():
    output = test_module_forward.forward(test_input)
print(f"✓ Forward pass succeeded, output shape: {output.shape}")
assert output.shape == (2, 5), f"output shape should be (2, 5), got {output.shape}"
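
# Illustrative numeric cross-check (uses only the tensor ops already used above): the
# module's forward output should match the same expression computed directly.
with torch.no_grad():
    expected = test_input @ test_module_forward.weight.t() + test_module_forward.bias
assert torch.allclose(output, expected), "forward output should match the direct computation"
print("✓ Forward output matches the direct computation")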

# Test 9: construction from None
param_empty = Parameter(None)
print(f"✓ Created Parameter from None, shape: {param_empty.shape}")
assert param_empty.shape == torch.Size([0]), "constructing from None should give an empty tensor"

# Test 10: deep copy
import copy
param_copy = copy.deepcopy(param1)
print(f"✓ Deep-copied Parameter, shape: {param_copy.shape}")
assert param_copy.shape == param1.shape, "the deep copy should have the same shape"
assert param_copy is not param1, "the deep copy should be a new object"
assert torch.equal(param_copy, param1), "the deep copy should preserve the values"

print("\n=== All tests passed! ===")