File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed
tests/unittest/_torch/attention/sparse Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -39,7 +39,7 @@ FetchContent_Declare(
3939FetchContent_Declare (
4040 deepgemm
4141 GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
42- GIT_TAG 9fa5965e265e27995f539e0dd73a06351a8a9eaf
42+ GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
4343 GIT_SUBMODULES_RECURSE
4444 ON
4545 SOURCE_SUBDIR
Original file line number Diff line number Diff line change @@ -308,9 +308,9 @@ def test_deepgemm_fp8_mqa_logits_basic():
308308 """
309309 torch .manual_seed (0 )
310310
311- num_heads , head_dim = 32 , 128
312- seq_len = 512
313- seq_len_kv = 1024
311+ num_heads , head_dim = 64 , 128
312+ seq_len = 2048
313+ seq_len_kv = 4096
314314 #[seq_len, num_heads, head_dim]
315315 q = torch .randn (
316316 seq_len ,
@@ -335,8 +335,8 @@ def test_deepgemm_fp8_mqa_logits_basic():
335335 )
336336 # ks[i] -> ke[i] for each q[i]
337337 ks = torch .zeros (seq_len , dtype = torch .int , device = "cuda" )
338- ke = torch .arange (seq_len , dtype = torch .int , device = "cuda" ) + (
339- seq_len_kv - seq_len ) + 1 # +1 for exclusive end
338+ ke = torch .arange (seq_len , dtype = torch .int ,
339+ device = "cuda" ) + ( seq_len_kv - seq_len )
340340
341341 # Convert to FP8
342342 q_fp8 = q .to (torch .float8_e4m3fn )
You can’t perform that action at this time.
0 commit comments