Skip to content

Commit 931bcdc

Browse files
committed
Update test configurations for CUDA forward equivalence to include additional head dimensions and adjust keep_window_size
1 parent 880b6e3 commit 931bcdc

File tree

1 file changed

+51
-39
lines changed

1 file changed

+51
-39
lines changed

benchmarks/forward_equivalence.py

Lines changed: 51 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -515,76 +515,88 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
515515
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
516516
test_configs = [
517517
# Head dim 32
518-
(1, 2, 1, 128, 128, 32, True),
519518
(1, 2, 1, 128, 128, 32, False),
520-
(1, 2, 1, 256, 256, 32, True),
519+
(1, 2, 1, 128, 128, 32, True),
521520
(1, 2, 1, 256, 256, 32, False),
522-
(1, 2, 1, 512, 512, 32, True),
521+
(1, 2, 1, 256, 256, 32, True),
523522
(1, 2, 1, 512, 512, 32, False),
524-
(1, 2, 1, 1024, 1024, 32, True),
523+
(1, 2, 1, 512, 512, 32, True),
525524
(1, 2, 1, 1024, 1024, 32, False),
526-
(1, 2, 1, 2048, 2048, 32, True),
525+
(1, 2, 1, 1024, 1024, 32, True),
527526
(1, 2, 1, 2048, 2048, 32, False),
528-
(1, 2, 1, 4096, 4096, 32, True),
527+
(1, 2, 1, 2048, 2048, 32, True),
529528
(1, 2, 1, 4096, 4096, 32, False),
529+
(1, 2, 1, 4096, 4096, 32, True),
530530

531531
# Head dim 64
532-
(1, 2, 1, 128, 128, 64, True),
533532
(1, 2, 1, 128, 128, 64, False),
534-
(1, 2, 1, 256, 256, 64, True),
533+
(1, 2, 1, 128, 128, 64, True),
535534
(1, 2, 1, 256, 256, 64, False),
536-
(1, 2, 1, 512, 512, 64, True),
535+
(1, 2, 1, 256, 256, 64, True),
537536
(1, 2, 1, 512, 512, 64, False),
538-
(1, 2, 1, 1024, 1024, 64, True),
537+
(1, 2, 1, 512, 512, 64, True),
539538
(1, 2, 1, 1024, 1024, 64, False),
540-
(1, 2, 1, 2048, 2048, 64, True),
539+
(1, 2, 1, 1024, 1024, 64, True),
541540
(1, 2, 1, 2048, 2048, 64, False),
542-
(1, 2, 1, 4096, 4096, 64, True),
541+
(1, 2, 1, 2048, 2048, 64, True),
543542
(1, 2, 1, 4096, 4096, 64, False),
543+
(1, 2, 1, 4096, 4096, 64, True),
544544

545545
# Head dim 96
546-
(1, 2, 1, 128, 128, 96, True),
547546
(1, 2, 1, 128, 128, 96, False),
548-
(1, 2, 1, 256, 256, 96, True),
547+
(1, 2, 1, 128, 128, 96, True),
549548
(1, 2, 1, 256, 256, 96, False),
550-
(1, 2, 1, 512, 512, 96, True),
549+
(1, 2, 1, 256, 256, 96, True),
551550
(1, 2, 1, 512, 512, 96, False),
552-
(1, 2, 1, 1024, 1024, 96, True),
551+
(1, 2, 1, 512, 512, 96, True),
553552
(1, 2, 1, 1024, 1024, 96, False),
554-
(1, 2, 1, 2048, 2048, 96, True),
553+
(1, 2, 1, 1024, 1024, 96, True),
555554
(1, 2, 1, 2048, 2048, 96, False),
556-
(1, 2, 1, 4096, 4096, 96, True),
555+
(1, 2, 1, 2048, 2048, 96, True),
557556
(1, 2, 1, 4096, 4096, 96, False),
557+
(1, 2, 1, 4096, 4096, 96, True),
558558

559559
# Head dim 128
560-
(1, 2, 1, 128, 128, 128, True),
561560
(1, 2, 1, 128, 128, 128, False),
562-
(1, 2, 1, 256, 256, 128, True),
561+
(1, 2, 1, 128, 128, 128, True),
563562
(1, 2, 1, 256, 256, 128, False),
564-
(1, 2, 1, 512, 512, 128, True),
563+
(1, 2, 1, 256, 256, 128, True),
565564
(1, 2, 1, 512, 512, 128, False),
566-
(1, 2, 1, 1024, 1024, 128, True),
565+
(1, 2, 1, 512, 512, 128, True),
567566
(1, 2, 1, 1024, 1024, 128, False),
568-
(1, 2, 1, 2048, 2048, 128, True),
567+
(1, 2, 1, 1024, 1024, 128, True),
569568
(1, 2, 1, 2048, 2048, 128, False),
570-
(1, 2, 1, 4096, 4096, 128, True),
569+
(1, 2, 1, 2048, 2048, 128, True),
571570
(1, 2, 1, 4096, 4096, 128, False),
571+
(1, 2, 1, 4096, 4096, 128, True),
572+
573+
# Head dim 192
574+
(1, 2, 1, 128, 128, 192, False),
575+
(1, 2, 1, 128, 128, 192, True),
576+
(1, 2, 1, 256, 256, 192, False),
577+
(1, 2, 1, 256, 256, 192, True),
578+
(1, 2, 1, 512, 512, 192, False),
579+
(1, 2, 1, 512, 512, 192, True),
580+
(1, 2, 1, 1024, 1024, 192, False),
581+
(1, 2, 1, 1024, 1024, 192, True),
582+
(1, 2, 1, 2048, 2048, 192, False),
583+
(1, 2, 1, 2048, 2048, 192, True),
584+
(1, 2, 1, 4096, 4096, 192, False),
585+
(1, 2, 1, 4096, 4096, 192, True),
572586

573-
# Not support head_dim = 256 in sm89 yet
574-
# Because fwd uses splitkv branch by default, and shared memory is not enough for sm89
575587
# Head dim 256
576-
# (1, 2, 1, 128, 128, 256, True),
577-
# (1, 2, 1, 128, 128, 256, False),
578-
# (1, 2, 1, 256, 256, 256, True),
579-
# (1, 2, 1, 256, 256, 256, False),
580-
# (1, 2, 1, 512, 512, 256, True),
581-
# (1, 2, 1, 512, 512, 256, False),
582-
# (1, 2, 1, 1024, 1024, 256, True),
583-
# (1, 2, 1, 1024, 1024, 256, False),
584-
# (1, 2, 1, 2048, 2048, 256, True),
585-
# (1, 2, 1, 2048, 2048, 256, False),
586-
# (1, 2, 1, 4096, 4096, 256, True),
587-
# (1, 2, 1, 4096, 4096, 256, False),
588+
(1, 2, 1, 128, 128, 256, False),
589+
(1, 2, 1, 128, 128, 256, True),
590+
(1, 2, 1, 256, 256, 256, False),
591+
(1, 2, 1, 256, 256, 256, True),
592+
(1, 2, 1, 512, 512, 256, False),
593+
(1, 2, 1, 512, 512, 256, True),
594+
(1, 2, 1, 1024, 1024, 256, False),
595+
(1, 2, 1, 1024, 1024, 256, True),
596+
(1, 2, 1, 2048, 2048, 256, False),
597+
(1, 2, 1, 2048, 2048, 256, True),
598+
(1, 2, 1, 4096, 4096, 256, False),
599+
(1, 2, 1, 4096, 4096, 256, True),
588600
]
589601

590602
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -635,7 +647,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
635647

636648
# Set scaling factor and keep window size
637649
scaling = head_dim ** -0.5
638-
keep_window_size = 64
650+
keep_window_size = 1024
639651

640652
# Run Python implementation
641653
start_time = time.time()

0 commit comments

Comments (0)