21
21
from .test_utils import attention_ref , generate_qkv , generate_random_padding_mask
22
22
23
23
common_settings = {
24
- "verbosity" : Verbosity .verbose ,
25
- "max_examples" : 20 ,
24
+ "verbosity" : Verbosity .normal ,
25
+ "max_examples" : 200 ,
26
26
"deadline" : None ,
27
27
"suppress_health_check" : [HealthCheck .filter_too_much , HealthCheck .data_too_large ],
28
28
}
29
29
30
30
DEBUG = False
31
+ SEED = 2
31
32
32
33
compute_capability = (0 , 0 )
33
34
if torch .cuda .is_available ():
@@ -50,21 +51,39 @@ def _allclose(
50
51
t_pt : torch .Tensor ,
51
52
) -> None :
52
53
assert t_test .shape == t_ref .shape == t_pt .shape
54
+
55
+ ratio = 2.0
56
+
57
+ # Calculate all differences
58
+ test_ref_diff = self ._abs_max (t_test - t_ref )
59
+ test_pt_diff = self ._abs_max (t_test - t_pt )
60
+ pt_ref_diff = self ._abs_max (t_pt - t_ref )
61
+
53
62
if DEBUG :
54
63
# Debug: Print the differences
64
+ print (f"DEBUG: Max absolute difference vs ref: { test_ref_diff } " )
65
+ print (f"DEBUG: Max absolute difference vs pt: { test_pt_diff } " )
66
+ print (f"DEBUG: Max absolute difference pt vs ref: { pt_ref_diff } " )
55
67
print (
56
- f"DEBUG: Max absolute difference vs ref: { self ._abs_max (t_test - t_ref )} "
57
- )
58
- print (
59
- f"DEBUG: Max absolute difference vs pt: { self ._abs_max (t_test - t_pt )} "
60
- )
61
- print (
62
- f"DEBUG: Max absolute difference pt vs ref: { self ._abs_max (t_pt - t_ref )} "
63
- )
64
- print (
65
- f"DEBUG: Tolerance check: { self ._abs_max (t_test - t_ref )} <= { 2 * self ._abs_max (t_pt - t_ref ) + 1e-5 } "
68
+ f"DEBUG: Tolerance check: { test_ref_diff } <= { ratio * pt_ref_diff + 1e-5 } "
66
69
)
67
- assert self ._abs_max (t_test - t_ref ) <= 2 * self ._abs_max (t_pt - t_ref ) + 1e-4
70
+
71
+ # First assertion with gap information
72
+ tolerance_threshold = ratio * pt_ref_diff + 1e-4
73
+ assert test_ref_diff <= tolerance_threshold , (
74
+ f"Tolerance check failed: max_diff={ test_ref_diff :.6f} > "
75
+ f"threshold={ tolerance_threshold :.6f} , gap={ test_ref_diff - tolerance_threshold :.6f} "
76
+ )
77
+
78
+ # sanity checks
79
+ assert test_ref_diff <= 0.5 , (
80
+ f"Max difference vs ref too large: { test_ref_diff :.6f} > 0.5, "
81
+ f"gap={ test_ref_diff - 0.5 :.6f} "
82
+ )
83
+ assert pt_ref_diff <= 0.5 , (
84
+ f"Max difference pt vs ref too large: { pt_ref_diff :.6f} > 0.5, "
85
+ f"gap={ pt_ref_diff - 0.5 :.6f} "
86
+ )
68
87
69
88
def _generate_qkv (
70
89
self ,
@@ -121,6 +140,7 @@ def _execute_cutlass_blackwell_attn_dense(
121
140
) -> None :
122
141
device = torch .accelerator .current_accelerator ()
123
142
assert device is not None
143
+ torch .manual_seed (SEED )
124
144
assert seqlen_q <= seqlen_k
125
145
126
146
# Initialize deterministic variables
@@ -263,6 +283,8 @@ def _execute_cutlass_blackwell_attn_varlen(
263
283
device = torch .accelerator .current_accelerator ()
264
284
assert device is not None
265
285
286
+ torch .manual_seed (SEED )
287
+
266
288
# Initialize deterministic variables
267
289
out_unpad_d = None
268
290
q_ref , k_ref , v_ref = self ._generate_qkv (
@@ -501,6 +523,8 @@ def test_jagged_vs_padded_kv(
501
523
head_dim = 128
502
524
dtype = torch .bfloat16
503
525
526
+ torch .manual_seed (SEED )
527
+
504
528
# Create tensors
505
529
q_padded = torch .randn (
506
530
batch_size ,
0 commit comments