Skip to content

Commit 7997d89

Browse files
committed
issue/900 - adjust doc text and test script
1 parent bea5068 commit 7997d89

File tree

2 files changed

+55
-56
lines changed

2 files changed

+55
-56
lines changed

test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
### 运行测试
66

77
```bash
8-
cd /home/zhuyue/codes/InfiniCore
8+
cd <path_to>/InfiniCore
99
python test/infinicore/nn/test_embedding_graph_recording.py
1010
```
1111

@@ -259,7 +259,7 @@ python test/infinicore/nn/test_embedding_graph_recording.py
259259
#!/bin/bash
260260
# quick_check.sh
261261

262-
cd /home/zhuyue/codes/InfiniCore
262+
cd <path_to>/InfiniCore
263263

264264
echo "=== 1. 代码检查 ==="
265265
if grep -q "to(cpu_device)" src/infinicore/nn/embedding.cc; then

test/infinicore/nn/test_embedding_graph_recording.py

Lines changed: 53 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,97 +15,95 @@
1515

1616
import infinicore
1717
import torch
18-
import ctypes
1918

2019

2120
def test_embedding_graph_recording():
2221
"""测试 embedding 是否支持 CUDA Graph 录制"""
2322
print("=" * 60)
2423
print("测试 Embedding 图录制支持")
2524
print("=" * 60)
26-
25+
2726
# 检查是否有 CUDA
2827
if not torch.cuda.is_available():
2928
print("⚠ CUDA 不可用,跳过图录制测试")
3029
return False
31-
30+
3231
device = infinicore.device("cuda", 0)
33-
32+
3433
# 创建 embedding 模块
3534
vocab_size = 1000
3635
embedding_dim = 128
3736
embedding = infinicore.nn.Embedding(
3837
num_embeddings=vocab_size,
3938
embedding_dim=embedding_dim,
4039
dtype=infinicore.float32,
41-
device=device
40+
device=device,
4241
)
43-
42+
4443
# 创建设备端的 input_ids(这是关键:改动前不支持,改动后支持)
4544
batch_size = 4
4645
seq_len = 32
4746
input_ids_device = infinicore.from_list(
4847
[[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)],
4948
dtype=infinicore.int64,
50-
device=device
49+
device=device,
5150
)
52-
51+
5352
print(f"\n1. 输入张量信息:")
5453
print(f" - Shape: {input_ids_device.shape}")
5554
print(f" - Device: {input_ids_device.device.type}")
5655
print(f" - Dtype: {input_ids_device.dtype}")
57-
56+
5857
# 尝试使用 CUDA Graph 录制
5958
print(f"\n2. 尝试 CUDA Graph 录制...")
60-
59+
6160
# 使用 PyTorch 的 CUDA Graph API 进行测试(更简单可靠)
6261
try:
6362
# 设置设备
6463
infinicore.set_device(device)
65-
64+
6665
# 使用 PyTorch 的 CUDA Graph API
6766
# 注意:PyTorch 2.0+ 支持 torch.cuda.graph
6867
try:
6968
# 方法 1: 使用 PyTorch 的 CUDA Graph(推荐)
7069
print(" 使用 PyTorch CUDA Graph API 测试...")
71-
70+
7271
# 创建 warmup 输入
7372
warmup_input = input_ids_device
74-
73+
7574
# Warmup(图录制前需要先执行一次,包括内存分配)
76-
warmup_output = embedding.forward(warmup_input)
75+
embedding.forward(warmup_input)
7776
infinicore.sync_stream() # 同步确保 warmup 完成
78-
77+
7978
# 预先分配输出张量(CUDA Graph 不支持动态内存分配)
8079
# 输出形状: input_shape + [embedding_dim]
8180
output_shape = list(input_ids_device.shape) + [embedding_dim]
8281
output = infinicore.empty(
83-
output_shape,
84-
dtype=embedding.weight.dtype,
85-
device=device
82+
output_shape, dtype=embedding.weight.dtype, device=device
8683
)
87-
84+
8885
# Warmup embedding(确保内存分配完成)
8986
import infinicore.nn.functional as F
87+
9088
F.embedding(warmup_input, embedding.weight, out=output)
9189
infinicore.sync_stream()
92-
90+
9391
# 开始图录制(使用预先分配的 output)
9492
graph = torch.cuda.CUDAGraph()
9593
with torch.cuda.graph(graph):
9694
# 使用 embedding 的 out 参数(in-place),传入预先分配的 output
9795
F.embedding(input_ids_device, embedding.weight, out=output)
98-
96+
9997
print(" ✓ 成功完成图录制!")
10098
print(" ✓ Embedding 支持 CUDA Graph 录制")
101-
99+
102100
# 验证图可以重复执行
103101
graph.replay()
104102
infinicore.sync_stream()
105-
103+
106104
print(" ✓ 图可以成功重放")
107105
return True
108-
106+
109107
except AttributeError:
110108
# PyTorch 版本可能不支持 torch.cuda.graph
111109
print(" ⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法")
@@ -119,69 +117,71 @@ def test_embedding_graph_recording():
119117
else:
120118
print(f" ⚠ 图录制测试异常: {e}")
121119
return test_embedding_async_verification(embedding, input_ids_device)
122-
120+
123121
except Exception as e:
124122
print(f" ⚠ 图录制测试异常: {e}")
125123
print(" 使用简化验证方法...")
126124
import traceback
125+
127126
traceback.print_exc()
128127
return test_embedding_async_verification(embedding, input_ids_device)
129128

130129

131130
def test_embedding_async_verification(embedding, input_ids_device):
132131
"""
133132
简化验证:检查是否有同步操作
134-
133+
135134
关键检查点:
136135
1. 输入是否可以在设备上(改动前需要 CPU,改动后支持设备)
137136
2. 操作是否完全异步(没有同步点)
138137
"""
139138
print("\n3. 简化验证:检查异步操作支持")
140-
139+
141140
# 验证 1: 输入可以在设备上
142141
if input_ids_device.device.type != "cuda":
143142
print(" ✗ 输入不在设备上,无法验证")
144143
return False
145-
144+
146145
print(" ✓ 输入在设备上")
147-
146+
148147
# 验证 2: 执行 forward,检查是否有同步操作
149148
# 如果改动前,这里会调用 indices->to(cpu_device),触发同步
150149
# 如果改动后,直接使用设备端 kernel,完全异步
151-
150+
152151
try:
153152
# 记录开始时间
154153
start_event = infinicore.DeviceEvent(enable_timing=True)
155154
end_event = infinicore.DeviceEvent(enable_timing=True)
156-
155+
157156
start_event.record()
158157
output = embedding.forward(input_ids_device)
159158
end_event.record()
160-
159+
161160
# 不立即同步,检查操作是否异步
162161
# 如果操作是异步的,query 应该返回 False(未完成)
163162
# 如果操作是同步的,可能已经完成
164-
163+
165164
# 等待一小段时间
166165
import time
166+
167167
time.sleep(0.001) # 1ms
168-
168+
169169
# 检查事件状态
170170
is_complete = end_event.query()
171-
171+
172172
if not is_complete:
173173
print(" ✓ 操作是异步的(事件未立即完成)")
174174
else:
175175
print(" ⚠ 操作可能包含同步点(事件立即完成)")
176-
176+
177177
# 同步并测量时间
178178
end_event.synchronize()
179179
elapsed = start_event.elapsed_time(end_event)
180-
180+
181181
print(f" ✓ Forward 执行时间: {elapsed:.3f} ms")
182182
print(f" ✓ 输出形状: {output.shape}")
183183
print(f" ✓ 输出设备: {output.device.type}")
184-
184+
185185
# 验证输出正确性
186186
embedding_dim = embedding.embedding_dim()
187187
expected_shape = (*input_ids_device.shape, embedding_dim)
@@ -193,10 +193,11 @@ def test_embedding_async_verification(embedding, input_ids_device):
193193
print(f" 期望形状: {expected_shape}, 实际形状: {output.shape}")
194194
print(f" 期望设备: cuda, 实际设备: {output.device.type}")
195195
return False
196-
196+
197197
except Exception as e:
198198
print(f" ✗ 验证失败: {e}")
199199
import traceback
200+
200201
traceback.print_exc()
201202
return False
202203

@@ -206,29 +207,27 @@ def test_embedding_device_input_support():
206207
print("\n" + "=" * 60)
207208
print("测试 Embedding 设备端输入支持")
208209
print("=" * 60)
209-
210+
210211
if not torch.cuda.is_available():
211212
print("⚠ CUDA 不可用,跳过测试")
212213
return False
213-
214+
214215
device = infinicore.device("cuda", 0)
215216
vocab_size = 100
216217
embedding_dim = 64
217-
218+
218219
embedding = infinicore.nn.Embedding(
219220
num_embeddings=vocab_size,
220221
embedding_dim=embedding_dim,
221222
dtype=infinicore.float32,
222-
device=device
223+
device=device,
223224
)
224-
225+
225226
# 测试 1: 设备端输入(改动后支持)
226227
print("\n测试 1: 设备端输入")
227228
try:
228229
input_ids_device = infinicore.from_list(
229-
[[1, 2, 3, 4, 5]],
230-
dtype=infinicore.int64,
231-
device=device
230+
[[1, 2, 3, 4, 5]], dtype=infinicore.int64, device=device
232231
)
233232
output = embedding.forward(input_ids_device)
234233
print(f" ✓ 设备端输入成功")
@@ -246,36 +245,36 @@ def main():
246245
print("\n" + "=" * 60)
247246
print("Embedding 图录制支持验证")
248247
print("=" * 60)
249-
248+
250249
results = []
251-
250+
252251
# 测试 1: 图录制支持
253252
result1 = test_embedding_graph_recording()
254253
results.append(("CUDA Graph 录制", result1))
255-
254+
256255
# 测试 2: 设备端输入支持
257256
result2 = test_embedding_device_input_support()
258257
results.append(("设备端输入", result2))
259-
258+
260259
# 总结
261260
print("\n" + "=" * 60)
262261
print("测试结果总结")
263262
print("=" * 60)
264-
263+
265264
all_passed = True
266265
for test_name, result in results:
267266
status = "✓ 通过" if result else "✗ 失败"
268267
print(f"{test_name}: {status}")
269268
if not result:
270269
all_passed = False
271-
270+
272271
print("\n" + "=" * 60)
273272
if all_passed:
274273
print("✓ 所有测试通过!Embedding 支持图录制")
275274
else:
276275
print("✗ 部分测试失败,Embedding 可能不完全支持图录制")
277276
print("=" * 60)
278-
277+
279278
return all_passed
280279

281280

0 commit comments

Comments
 (0)