Skip to content

Commit a32563a

Browse files
author
wangzaijun
committed
fix
1 parent 0f7dedb commit a32563a

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

test/kernel/llama_gqa_decode_vsm_tuning.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -93,12 +93,14 @@ def inner_alloc_func(shape, dtype=torch.float32, device="cuda"):
9393

9494
graph.replay()
9595

96-
torch.cuda.synchronize()
97-
start = time.time()
98-
# graph.replay()
99-
torch.cuda.synchronize()
96+
start_event = torch.cuda.Event(enable_timing=True)
97+
end_event = torch.cuda.Event(enable_timing=True)
98+
start_event.record()
99+
graph.replay()
100+
end_event.record()
101+
end_event.synchronize()
100102

101-
cost_time = (time.time() - start) * 1000
103+
cost_time = start_event.elapsed_time(end_event=end_event)
102104

103105
logger.info(f"fp16 {test_seq_len} cost time: {cost_time} ms")
104106
return cost_time

test/kernel/llama_gqa_diverse_decode_stage1_tuning.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -122,12 +122,14 @@ def test_decode_attentions(
122122

123123
graph.replay()
124124

125-
torch.cuda.synchronize()
126-
start = time.time()
127-
# graph.replay()
128-
torch.cuda.synchronize()
125+
start_event = torch.cuda.Event(enable_timing=True)
126+
end_event = torch.cuda.Event(enable_timing=True)
127+
start_event.record()
128+
graph.replay()
129+
end_event.record()
130+
end_event.synchronize()
129131

130-
cost_time = (time.time() - start) * 1000
132+
cost_time = start_event.elapsed_time(end_event=end_event)
131133

132134
logger.info(f"fp16 {test_seq_len} cost time: {cost_time} ms")
133135
return cost_time

unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def setup_tensors():
1515
max_batch_group_size = 4
1616
quant_group_size = 8
1717

18-
test_dtype = torch.float32
18+
test_dtype = torch.bfloat16
1919

2020
kv_shape = (batch_size * seq_len, kv_head_num, head_dim)
2121
kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)

0 commit comments

Comments (0)