@@ -515,76 +515,88 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
515515 # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
516516 test_configs = [
517517 # Head dim 32
518- (1 , 2 , 1 , 128 , 128 , 32 , True ),
519518 (1 , 2 , 1 , 128 , 128 , 32 , False ),
520- (1 , 2 , 1 , 256 , 256 , 32 , True ),
519+ (1 , 2 , 1 , 128 , 128 , 32 , True ),
521520 (1 , 2 , 1 , 256 , 256 , 32 , False ),
522- (1 , 2 , 1 , 512 , 512 , 32 , True ),
521+ (1 , 2 , 1 , 256 , 256 , 32 , True ),
523522 (1 , 2 , 1 , 512 , 512 , 32 , False ),
524- (1 , 2 , 1 , 1024 , 1024 , 32 , True ),
523+ (1 , 2 , 1 , 512 , 512 , 32 , True ),
525524 (1 , 2 , 1 , 1024 , 1024 , 32 , False ),
526- (1 , 2 , 1 , 2048 , 2048 , 32 , True ),
525+ (1 , 2 , 1 , 1024 , 1024 , 32 , True ),
527526 (1 , 2 , 1 , 2048 , 2048 , 32 , False ),
528- (1 , 2 , 1 , 4096 , 4096 , 32 , True ),
527+ (1 , 2 , 1 , 2048 , 2048 , 32 , True ),
529528 (1 , 2 , 1 , 4096 , 4096 , 32 , False ),
529+ (1 , 2 , 1 , 4096 , 4096 , 32 , True ),
530530
531531 # Head dim 64
532- (1 , 2 , 1 , 128 , 128 , 64 , True ),
533532 (1 , 2 , 1 , 128 , 128 , 64 , False ),
534- (1 , 2 , 1 , 256 , 256 , 64 , True ),
533+ (1 , 2 , 1 , 128 , 128 , 64 , True ),
535534 (1 , 2 , 1 , 256 , 256 , 64 , False ),
536- (1 , 2 , 1 , 512 , 512 , 64 , True ),
535+ (1 , 2 , 1 , 256 , 256 , 64 , True ),
537536 (1 , 2 , 1 , 512 , 512 , 64 , False ),
538- (1 , 2 , 1 , 1024 , 1024 , 64 , True ),
537+ (1 , 2 , 1 , 512 , 512 , 64 , True ),
539538 (1 , 2 , 1 , 1024 , 1024 , 64 , False ),
540- (1 , 2 , 1 , 2048 , 2048 , 64 , True ),
539+ (1 , 2 , 1 , 1024 , 1024 , 64 , True ),
541540 (1 , 2 , 1 , 2048 , 2048 , 64 , False ),
542- (1 , 2 , 1 , 4096 , 4096 , 64 , True ),
541+ (1 , 2 , 1 , 2048 , 2048 , 64 , True ),
543542 (1 , 2 , 1 , 4096 , 4096 , 64 , False ),
543+ (1 , 2 , 1 , 4096 , 4096 , 64 , True ),
544544
545545 # Head dim 96
546- (1 , 2 , 1 , 128 , 128 , 96 , True ),
547546 (1 , 2 , 1 , 128 , 128 , 96 , False ),
548- (1 , 2 , 1 , 256 , 256 , 96 , True ),
547+ (1 , 2 , 1 , 128 , 128 , 96 , True ),
549548 (1 , 2 , 1 , 256 , 256 , 96 , False ),
550- (1 , 2 , 1 , 512 , 512 , 96 , True ),
549+ (1 , 2 , 1 , 256 , 256 , 96 , True ),
551550 (1 , 2 , 1 , 512 , 512 , 96 , False ),
552- (1 , 2 , 1 , 1024 , 1024 , 96 , True ),
551+ (1 , 2 , 1 , 512 , 512 , 96 , True ),
553552 (1 , 2 , 1 , 1024 , 1024 , 96 , False ),
554- (1 , 2 , 1 , 2048 , 2048 , 96 , True ),
553+ (1 , 2 , 1 , 1024 , 1024 , 96 , True ),
555554 (1 , 2 , 1 , 2048 , 2048 , 96 , False ),
556- (1 , 2 , 1 , 4096 , 4096 , 96 , True ),
555+ (1 , 2 , 1 , 2048 , 2048 , 96 , True ),
557556 (1 , 2 , 1 , 4096 , 4096 , 96 , False ),
557+ (1 , 2 , 1 , 4096 , 4096 , 96 , True ),
558558
559559 # Head dim 128
560- (1 , 2 , 1 , 128 , 128 , 128 , True ),
561560 (1 , 2 , 1 , 128 , 128 , 128 , False ),
562- (1 , 2 , 1 , 256 , 256 , 128 , True ),
561+ (1 , 2 , 1 , 128 , 128 , 128 , True ),
563562 (1 , 2 , 1 , 256 , 256 , 128 , False ),
564- (1 , 2 , 1 , 512 , 512 , 128 , True ),
563+ (1 , 2 , 1 , 256 , 256 , 128 , True ),
565564 (1 , 2 , 1 , 512 , 512 , 128 , False ),
566- (1 , 2 , 1 , 1024 , 1024 , 128 , True ),
565+ (1 , 2 , 1 , 512 , 512 , 128 , True ),
567566 (1 , 2 , 1 , 1024 , 1024 , 128 , False ),
568- (1 , 2 , 1 , 2048 , 2048 , 128 , True ),
567+ (1 , 2 , 1 , 1024 , 1024 , 128 , True ),
569568 (1 , 2 , 1 , 2048 , 2048 , 128 , False ),
570- (1 , 2 , 1 , 4096 , 4096 , 128 , True ),
569+ (1 , 2 , 1 , 2048 , 2048 , 128 , True ),
571570 (1 , 2 , 1 , 4096 , 4096 , 128 , False ),
571+ (1 , 2 , 1 , 4096 , 4096 , 128 , True ),
572+
573+ # Head dim 192
574+ (1 , 2 , 1 , 128 , 128 , 192 , False ),
575+ (1 , 2 , 1 , 128 , 128 , 192 , True ),
576+ (1 , 2 , 1 , 256 , 256 , 192 , False ),
577+ (1 , 2 , 1 , 256 , 256 , 192 , True ),
578+ (1 , 2 , 1 , 512 , 512 , 192 , False ),
579+ (1 , 2 , 1 , 512 , 512 , 192 , True ),
580+ (1 , 2 , 1 , 1024 , 1024 , 192 , False ),
581+ (1 , 2 , 1 , 1024 , 1024 , 192 , True ),
582+ (1 , 2 , 1 , 2048 , 2048 , 192 , False ),
583+ (1 , 2 , 1 , 2048 , 2048 , 192 , True ),
584+ (1 , 2 , 1 , 4096 , 4096 , 192 , False ),
585+ (1 , 2 , 1 , 4096 , 4096 , 192 , True ),
572586
573- # Not support head_dim = 256 in sm89 yet
574- # Because fwd uses splitkv branch by default, and shared memory is not enough for sm89
575587 # Head dim 256
576- # (1, 2, 1, 128, 128, 256, True ),
577- # (1, 2, 1, 128, 128, 256, False ),
578- # (1, 2, 1, 256, 256, 256, True ),
579- # (1, 2, 1, 256, 256, 256, False ),
580- # (1, 2, 1, 512, 512, 256, True ),
581- # (1, 2, 1, 512, 512, 256, False ),
582- # (1, 2, 1, 1024, 1024, 256, True ),
583- # (1, 2, 1, 1024, 1024, 256, False ),
584- # (1, 2, 1, 2048, 2048, 256, True ),
585- # (1, 2, 1, 2048, 2048, 256, False ),
586- # (1, 2, 1, 4096, 4096, 256, True ),
587- # (1, 2, 1, 4096, 4096, 256, False ),
588+ (1 , 2 , 1 , 128 , 128 , 256 , False ),
589+ (1 , 2 , 1 , 128 , 128 , 256 , True ),
590+ (1 , 2 , 1 , 256 , 256 , 256 , False ),
591+ (1 , 2 , 1 , 256 , 256 , 256 , True ),
592+ (1 , 2 , 1 , 512 , 512 , 256 , False ),
593+ (1 , 2 , 1 , 512 , 512 , 256 , True ),
594+ (1 , 2 , 1 , 1024 , 1024 , 256 , False ),
595+ (1 , 2 , 1 , 1024 , 1024 , 256 , True ),
596+ (1 , 2 , 1 , 2048 , 2048 , 256 , False ),
597+ (1 , 2 , 1 , 2048 , 2048 , 256 , True ),
598+ (1 , 2 , 1 , 4096 , 4096 , 256 , False ),
599+ (1 , 2 , 1 , 4096 , 4096 , 256 , True ),
588600 ]
589601
590602 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -635,7 +647,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
635647
636648 # Set scaling factor and keep window size
637649 scaling = head_dim ** - 0.5
638- keep_window_size = 64
650+ keep_window_size = 1024
639651
640652 # Run Python implementation
641653 start_time = time .time ()
0 commit comments