@@ -44,6 +44,7 @@ def __init__(self):
4444 def forward (self ):
4545 return infinicore .add (self .a , self .b )
4646
47+
4748infinicore_model_infer = InfiniCoreNet ()
4849# ============================================================
4950# 2. 加载权重
@@ -75,6 +76,98 @@ def forward(self):
7576
7677
7778# ============================================================
78- # 5. to测试,buffer测试
79+ # 5. to测试 - 测试模型在不同设备间的转换
7980# ============================================================
80- # 等待添加
81+ print ("\n " + "=" * 60 )
82+ print ("5. to测试 - 设备转换测试" )
83+ print ("=" * 60 )
84+
85+ # 5.1 打印初始状态
86+ print ("\n 5.1 初始状态:" )
87+ print ("-" * 40 )
88+ print ("Parameters:" )
89+ for name , param in infinicore_model_infer .named_parameters ():
90+ print (f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } " )
91+ print ("Buffers:" )
92+ buffers_exist = False
93+ for name , buf in infinicore_model_infer .named_buffers ():
94+ buffers_exist = True
95+ print (f" { name } : shape={ buf .shape } , dtype={ buf .dtype } , device={ buf .device } " )
96+ if not buffers_exist :
97+ print (" (无buffers)" )
98+
99+ # 5.2 测试转换到CUDA设备(使用device对象)
100+ print ("\n 5.2 转换到CUDA设备 (使用 infinicore.device('cuda', 0)):" )
101+ print ("-" * 40 )
102+ target_device_cuda = infinicore .device ("cuda" , 0 )
103+ infinicore_model_infer .to (target_device_cuda )
104+
105+ print ("转换后的Parameters:" )
106+ for name , param in infinicore_model_infer .named_parameters ():
107+ print (f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } " )
108+ # 验证设备是否正确转换
109+ assert param .device == target_device_cuda , (
110+ f"参数 { name } 的设备转换失败: 期望 { target_device_cuda } , 实际 { param .device } "
111+ )
112+ if buffers_exist :
113+ print ("转换后的Buffers:" )
114+ for name , buf in infinicore_model_infer .named_buffers ():
115+ print (f" { name } : shape={ buf .shape } , dtype={ buf .dtype } , device={ buf .device } " )
116+ assert buf .device == target_device_cuda , (
117+ f"Buffer { name } 的设备转换失败: 期望 { target_device_cuda } , 实际 { buf .device } "
118+ )
119+ print ("✓ CUDA设备转换验证通过" )
120+
121+ # 5.3 测试转换到CPU设备(使用device对象)
122+ print ("\n 5.3 转换到CPU设备 (使用 infinicore.device('cpu', 0)):" )
123+ print ("-" * 40 )
124+ target_device_cpu = infinicore .device ("cpu" , 0 )
125+ infinicore_model_infer .to (target_device_cpu )
126+
127+ print ("转换后的Parameters:" )
128+ for name , param in infinicore_model_infer .named_parameters ():
129+ print (f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } " )
130+ # 验证设备是否正确转换
131+ assert param .device == target_device_cpu , (
132+ f"参数 { name } 的设备转换失败: 期望 { target_device_cpu } , 实际 { param .device } "
133+ )
134+ if buffers_exist :
135+ print ("转换后的Buffers:" )
136+ for name , buf in infinicore_model_infer .named_buffers ():
137+ print (f" { name } : shape={ buf .shape } , dtype={ buf .dtype } , device={ buf .device } " )
138+ assert buf .device == target_device_cpu , (
139+ f"Buffer { name } 的设备转换失败: 期望 { target_device_cpu } , 实际 { buf .device } "
140+ )
141+ print ("✓ CPU设备转换验证通过" )
142+
143+ # 5.4 测试使用字符串参数转换到CUDA设备
144+ print ("\n 5.4 转换到CUDA设备 (使用字符串 'cuda'):" )
145+ print ("-" * 40 )
146+ infinicore_model_infer .to ("cuda" )
147+
148+ print ("转换后的Parameters:" )
149+ for name , param in infinicore_model_infer .named_parameters ():
150+ print (f" { name } : shape={ param .shape } , dtype={ param .dtype } , device={ param .device } " )
151+ # 验证设备是否正确转换(字符串'cuda'会被转换为cuda设备)
152+ assert param .device .type == "cuda" , (
153+ f"参数 { name } 的设备转换失败: 期望 cuda, 实际 { param .device .type } "
154+ )
155+ if buffers_exist :
156+ print ("转换后的Buffers:" )
157+ for name , buf in infinicore_model_infer .named_buffers ():
158+ print (f" { name } : shape={ buf .shape } , dtype={ buf .dtype } , device={ buf .device } " )
159+ assert buf .device .type == "cuda" , (
160+ f"Buffer { name } 的设备转换失败: 期望 cuda, 实际 { buf .device .type } "
161+ )
162+ print ("✓ 字符串参数设备转换验证通过" )
163+
164+ # 5.5 验证to方法返回self(链式调用支持)
165+ print ("\n 5.5 测试to方法的返回值(链式调用):" )
166+ print ("-" * 40 )
167+ result = infinicore_model_infer .to (infinicore .device ("cpu" , 0 ))
168+ assert result is infinicore_model_infer , "to方法应该返回self以支持链式调用"
169+ print ("✓ to方法返回值验证通过" )
170+
171+ print ("\n " + "=" * 60 )
172+ print ("所有to测试通过!" )
173+ print ("=" * 60 + "\n " )
0 commit comments