1515
1616import infinicore
1717import torch
18- import ctypes
1918
2019
2120def test_embedding_graph_recording ():
2221 """测试 embedding 是否支持 CUDA Graph 录制"""
2322 print ("=" * 60 )
2423 print ("测试 Embedding 图录制支持" )
2524 print ("=" * 60 )
26-
25+
2726 # 检查是否有 CUDA
2827 if not torch .cuda .is_available ():
2928 print ("⚠ CUDA 不可用,跳过图录制测试" )
3029 return False
31-
30+
3231 device = infinicore .device ("cuda" , 0 )
33-
32+
3433 # 创建 embedding 模块
3534 vocab_size = 1000
3635 embedding_dim = 128
3736 embedding = infinicore .nn .Embedding (
3837 num_embeddings = vocab_size ,
3938 embedding_dim = embedding_dim ,
4039 dtype = infinicore .float32 ,
41- device = device
40+ device = device ,
4241 )
43-
42+
4443 # 创建设备端的 input_ids(这是关键:改动前不支持,改动后支持)
4544 batch_size = 4
4645 seq_len = 32
4746 input_ids_device = infinicore .from_list (
4847 [[i % vocab_size for i in range (seq_len )] for _ in range (batch_size )],
4948 dtype = infinicore .int64 ,
50- device = device
49+ device = device ,
5150 )
52-
51+
5352 print (f"\n 1. 输入张量信息:" )
5453 print (f" - Shape: { input_ids_device .shape } " )
5554 print (f" - Device: { input_ids_device .device .type } " )
5655 print (f" - Dtype: { input_ids_device .dtype } " )
57-
56+
5857 # 尝试使用 CUDA Graph 录制
5958 print (f"\n 2. 尝试 CUDA Graph 录制..." )
60-
59+
6160 # 使用 PyTorch 的 CUDA Graph API 进行测试(更简单可靠)
6261 try :
6362 # 设置设备
6463 infinicore .set_device (device )
65-
64+
6665 # 使用 PyTorch 的 CUDA Graph API
6766 # 注意:PyTorch 2.0+ 支持 torch.cuda.graph
6867 try :
6968 # 方法 1: 使用 PyTorch 的 CUDA Graph(推荐)
7069 print (" 使用 PyTorch CUDA Graph API 测试..." )
71-
70+
7271 # 创建 warmup 输入
7372 warmup_input = input_ids_device
74-
73+
7574 # Warmup(图录制前需要先执行一次,包括内存分配)
76- warmup_output = embedding .forward (warmup_input )
75+ embedding .forward (warmup_input )
7776 infinicore .sync_stream () # 同步确保 warmup 完成
78-
77+
7978 # 预先分配输出张量(CUDA Graph 不支持动态内存分配)
8079 # 输出形状: input_shape + [embedding_dim]
8180 output_shape = list (input_ids_device .shape ) + [embedding_dim ]
8281 output = infinicore .empty (
83- output_shape ,
84- dtype = embedding .weight .dtype ,
85- device = device
82+ output_shape , dtype = embedding .weight .dtype , device = device
8683 )
87-
84+
8885 # Warmup embedding(确保内存分配完成)
8986 import infinicore .nn .functional as F
87+
9088 F .embedding (warmup_input , embedding .weight , out = output )
9189 infinicore .sync_stream ()
92-
90+
9391 # 开始图录制(使用预先分配的 output)
9492 graph = torch .cuda .CUDAGraph ()
9593 with torch .cuda .graph (graph ):
9694 # 使用 embedding 的 out 参数(in-place),传入预先分配的 output
9795 F .embedding (input_ids_device , embedding .weight , out = output )
98-
96+
9997 print (" ✓ 成功完成图录制!" )
10098 print (" ✓ Embedding 支持 CUDA Graph 录制" )
101-
99+
102100 # 验证图可以重复执行
103101 graph .replay ()
104102 infinicore .sync_stream ()
105-
103+
106104 print (" ✓ 图可以成功重放" )
107105 return True
108-
106+
109107 except AttributeError :
110108 # PyTorch 版本可能不支持 torch.cuda.graph
111109 print (" ⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法" )
@@ -119,69 +117,71 @@ def test_embedding_graph_recording():
119117 else :
120118 print (f" ⚠ 图录制测试异常: { e } " )
121119 return test_embedding_async_verification (embedding , input_ids_device )
122-
120+
123121 except Exception as e :
124122 print (f" ⚠ 图录制测试异常: { e } " )
125123 print (" 使用简化验证方法..." )
126124 import traceback
125+
127126 traceback .print_exc ()
128127 return test_embedding_async_verification (embedding , input_ids_device )
129128
130129
131130def test_embedding_async_verification (embedding , input_ids_device ):
132131 """
133132 简化验证:检查是否有同步操作
134-
133+
135134 关键检查点:
136135 1. 输入是否可以在设备上(改动前需要 CPU,改动后支持设备)
137136 2. 操作是否完全异步(没有同步点)
138137 """
139138 print ("\n 3. 简化验证:检查异步操作支持" )
140-
139+
141140 # 验证 1: 输入可以在设备上
142141 if input_ids_device .device .type != "cuda" :
143142 print (" ✗ 输入不在设备上,无法验证" )
144143 return False
145-
144+
146145 print (" ✓ 输入在设备上" )
147-
146+
148147 # 验证 2: 执行 forward,检查是否有同步操作
149148 # 如果改动前,这里会调用 indices->to(cpu_device),触发同步
150149 # 如果改动后,直接使用设备端 kernel,完全异步
151-
150+
152151 try :
153152 # 记录开始时间
154153 start_event = infinicore .DeviceEvent (enable_timing = True )
155154 end_event = infinicore .DeviceEvent (enable_timing = True )
156-
155+
157156 start_event .record ()
158157 output = embedding .forward (input_ids_device )
159158 end_event .record ()
160-
159+
161160 # 不立即同步,检查操作是否异步
162161 # 如果操作是异步的,query 应该返回 False(未完成)
163162 # 如果操作是同步的,可能已经完成
164-
163+
165164 # 等待一小段时间
166165 import time
166+
167167 time .sleep (0.001 ) # 1ms
168-
168+
169169 # 检查事件状态
170170 is_complete = end_event .query ()
171-
171+
172172 if not is_complete :
173173 print (" ✓ 操作是异步的(事件未立即完成)" )
174174 else :
175175 print (" ⚠ 操作可能包含同步点(事件立即完成)" )
176-
176+
177177 # 同步并测量时间
178178 end_event .synchronize ()
179179 elapsed = start_event .elapsed_time (end_event )
180-
180+
181181 print (f" ✓ Forward 执行时间: { elapsed :.3f} ms" )
182182 print (f" ✓ 输出形状: { output .shape } " )
183183 print (f" ✓ 输出设备: { output .device .type } " )
184-
184+
185185 # 验证输出正确性
186186 embedding_dim = embedding .embedding_dim ()
187187 expected_shape = (* input_ids_device .shape , embedding_dim )
@@ -193,10 +193,11 @@ def test_embedding_async_verification(embedding, input_ids_device):
193193 print (f" 期望形状: { expected_shape } , 实际形状: { output .shape } " )
194194 print (f" 期望设备: cuda, 实际设备: { output .device .type } " )
195195 return False
196-
196+
197197 except Exception as e :
198198 print (f" ✗ 验证失败: { e } " )
199199 import traceback
200+
200201 traceback .print_exc ()
201202 return False
202203
@@ -206,29 +207,27 @@ def test_embedding_device_input_support():
206207 print ("\n " + "=" * 60 )
207208 print ("测试 Embedding 设备端输入支持" )
208209 print ("=" * 60 )
209-
210+
210211 if not torch .cuda .is_available ():
211212 print ("⚠ CUDA 不可用,跳过测试" )
212213 return False
213-
214+
214215 device = infinicore .device ("cuda" , 0 )
215216 vocab_size = 100
216217 embedding_dim = 64
217-
218+
218219 embedding = infinicore .nn .Embedding (
219220 num_embeddings = vocab_size ,
220221 embedding_dim = embedding_dim ,
221222 dtype = infinicore .float32 ,
222- device = device
223+ device = device ,
223224 )
224-
225+
225226 # 测试 1: 设备端输入(改动后支持)
226227 print ("\n 测试 1: 设备端输入" )
227228 try :
228229 input_ids_device = infinicore .from_list (
229- [[1 , 2 , 3 , 4 , 5 ]],
230- dtype = infinicore .int64 ,
231- device = device
230+ [[1 , 2 , 3 , 4 , 5 ]], dtype = infinicore .int64 , device = device
232231 )
233232 output = embedding .forward (input_ids_device )
234233 print (f" ✓ 设备端输入成功" )
@@ -246,36 +245,36 @@ def main():
246245 print ("\n " + "=" * 60 )
247246 print ("Embedding 图录制支持验证" )
248247 print ("=" * 60 )
249-
248+
250249 results = []
251-
250+
252251 # 测试 1: 图录制支持
253252 result1 = test_embedding_graph_recording ()
254253 results .append (("CUDA Graph 录制" , result1 ))
255-
254+
256255 # 测试 2: 设备端输入支持
257256 result2 = test_embedding_device_input_support ()
258257 results .append (("设备端输入" , result2 ))
259-
258+
260259 # 总结
261260 print ("\n " + "=" * 60 )
262261 print ("测试结果总结" )
263262 print ("=" * 60 )
264-
263+
265264 all_passed = True
266265 for test_name , result in results :
267266 status = "✓ 通过" if result else "✗ 失败"
268267 print (f"{ test_name } : { status } " )
269268 if not result :
270269 all_passed = False
271-
270+
272271 print ("\n " + "=" * 60 )
273272 if all_passed :
274273 print ("✓ 所有测试通过!Embedding 支持图录制" )
275274 else :
276275 print ("✗ 部分测试失败,Embedding 可能不完全支持图录制" )
277276 print ("=" * 60 )
278-
277+
279278 return all_passed
280279
281280
0 commit comments