 BLOCK_N = 64

 def pytorch_test(Q, K, V, block_sparse_mask, dO):
-    q_ = Q.clone().requires_grad_()
-    k_ = K.clone().requires_grad_()
-    v_ = V.clone().requires_grad_()
+    q_ = Q.clone().float().requires_grad_()
+    k_ = K.clone().float().requires_grad_()
+    v_ = V.clone().float().requires_grad_()

     QK = torch.matmul(q_, k_.transpose(-2, -1))
     QK /= (q_.size(-1) ** 0.5)
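This hunk upcasts the PyTorch reference to fp32 before enabling autograd, so the baseline gradients are accumulated in full precision while the kernel under test stays in bf16. A minimal sketch of the same pattern, independent of this file:

import torch

x_bf16 = torch.randn(8, 8, dtype=torch.bfloat16)
x_ref = x_bf16.clone().float().requires_grad_()  # fp32 leaf for the reference path
loss = (x_ref @ x_ref.transpose(-2, -1)).sum()
loss.backward()  # x_ref.grad is computed and stored in fp32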
@@ -35,9 +35,9 @@ def pytorch_test(Q, K, V, block_sparse_mask, dO):


 def block_sparse_kernel_test(Q, K, V, block_sparse_mask, variable_block_sizes, non_pad_index, dO):
-    Q = Q.clone().requires_grad_()
-    K = K.clone().requires_grad_()
-    V = V.clone().requires_grad_()
+    Q = Q.detach().requires_grad_()
+    K = K.detach().requires_grad_()
+    V = V.detach().requires_grad_()

     q_padded = vsa_pad(Q, non_pad_index, variable_block_sizes.shape[0], BLOCK_M)
     k_padded = vsa_pad(K, non_pad_index, variable_block_sizes.shape[0], BLOCK_M)
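The kernel-side test switches from clone() to detach(): both yield a leaf tensor once requires_grad_() is applied, but detach() shares storage with the input instead of copying it, so the kernel reads exactly the bf16 values the fp32 reference was cast from. A small illustration on toy tensors (not from this commit):

import torch

base = torch.randn(4, 4)
a = base.clone().requires_grad_()   # new storage, values copied
b = base.detach().requires_grad_()  # same storage as base, no copy
assert b.data_ptr() == base.data_ptr()
assert a.data_ptr() != base.data_ptr()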
@@ -60,11 +60,9 @@ def get_non_pad_index(

     return index_pad[index_mask]

-def generate_tensor(shape, mean, std, dtype, device):
+def generate_tensor(shape, dtype, device):
     tensor = torch.randn(shape, dtype=dtype, device=device)
-    magnitude = torch.norm(tensor, dim=-1, keepdim=True)
-    scaled_tensor = tensor * (torch.randn(magnitude.shape, dtype=dtype, device=device) * std + mean) / magnitude
-    return scaled_tensor.contiguous()
+    return tensor

 def generate_variable_block_sizes(num_blocks, min_size=32, max_size=64, device="cuda"):
     return torch.randint(min_size, max_size + 1, (num_blocks,), device=device, dtype=torch.int32)
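generate_tensor is reduced to plain torch.randn; the per-row norm rescaling driven by (mean, std) is gone, which is why those parameters disappear from every signature below. For orientation, here is the relationship between the variable block sizes and the padded layout vsa_pad produces, sketched under the assumption that S is the sum of the block sizes:

import torch

BLOCK_M = 64
num_blocks = 4
sizes = torch.randint(32, 65, (num_blocks,), dtype=torch.int32)
S = int(sizes.sum())              # unpadded sequence length (assumed)
S_padded = num_blocks * BLOCK_M   # every block padded up to BLOCK_M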
@@ -75,7 +73,7 @@ def vsa_pad(x, non_pad_index, num_blocks, block_size):
     padded_x[:, :, non_pad_index, :] = x
     return padded_x

-def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_mode='all'):
+def check_correctness(h, d, num_blocks, k, num_iterations=20, error_mode='all'):
     results = {
         'gO': {'sum_diff': 0.0, 'sum_abs': 0.0, 'max_diff': 0.0},
         'gQ': {'sum_diff': 0.0, 'sum_abs': 0.0, 'max_diff': 0.0},
@@ -91,10 +89,10 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_m
     block_mask = generate_block_sparse_mask_for_function(h, num_blocks, k, device)
     full_mask = create_full_mask_from_block_mask(block_mask, variable_block_sizes, device)
     for _ in range(num_iterations):
-        Q = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        K = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        V = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
-        dO = generate_tensor((1, h, S, d), mean, std, torch.bfloat16, device)
+        Q = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        K = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        V = generate_tensor((1, h, S, d), torch.bfloat16, device)
+        dO = generate_tensor((1, h, S, d), torch.bfloat16, device)

         # dO_padded = torch.zeros_like(dO_padded)
         # dO_padded[:, :, non_pad_index, :] = dO
@@ -107,7 +105,8 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_m
             abs_diff = torch.abs(diff)
             results[name]['sum_diff'] += torch.sum(abs_diff).item()
             results[name]['sum_abs'] += torch.sum(torch.abs(pt)).item()
-            results[name]['max_diff'] = max(results[name]['max_diff'], torch.max(abs_diff).item())
+            rel_max_diff = torch.max(abs_diff) / torch.mean(torch.abs(pt))
+            results[name]['max_diff'] = max(results[name]['max_diff'], rel_max_diff.item())
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

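The max-error statistic is now normalized by the mean magnitude of the reference gradient, which makes the "Rel ... Max" columns comparable across tensors of different scales. A toy reproduction of the metric:

import torch

pt = torch.randn(1024)              # reference gradient
tk = pt + 1e-3 * torch.randn(1024)  # kernel gradient with small noise
abs_diff = torch.abs(tk - pt)
rel_max_diff = torch.max(abs_diff) / torch.mean(torch.abs(pt))
print(rel_max_diff.item())          # max |error| relative to mean |reference|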
@@ -119,27 +118,27 @@ def check_correctness(h, d, num_blocks, k, mean, std, num_iterations=20, error_m

     return results

-def generate_error_graphs(h, d, mean, std, error_mode='all'):
+def generate_error_graphs(h, d, error_mode='all'):
     test_configs = [
         {"num_blocks": 16, "k": 2, "description": "Small sequence"},
         {"num_blocks": 32, "k": 4, "description": "Medium sequence"},
         {"num_blocks": 53, "k": 6, "description": "Large sequence"},
     ]

-    print(f"\nError Analysis for h={h}, d={d}, mean={mean}, std={std}, mode={error_mode}")
+    print(f"\nError Analysis for h={h}, d={d}, mode={error_mode}")
     print("=" * 150)
     print(f"{'Config':<20} {'Blocks':<8} {'K':<4} "
-          f"{'gQ Avg':<12} {'gQ Max':<12} "
-          f"{'gK Avg':<12} {'gK Max':<12} "
-          f"{'gV Avg':<12} {'gV Max':<12} "
-          f"{'gO Avg':<12} {'gO Max':<12} ")
+          f"{'gQ Avg':<12} {'Rel gQ Max':<12} "
+          f"{'gK Avg':<12} {'Rel gK Max':<12} "
+          f"{'gV Avg':<12} {'Rel gV Max':<12} "
+          f"{'gO Avg':<12} {'Rel gO Max':<12} ")
     print("-" * 150)

     for config in test_configs:
         num_blocks = config["num_blocks"]
         k = config["k"]
         description = config["description"]
-        results = check_correctness(h, d, num_blocks, k, mean, std, error_mode=error_mode)
+        results = check_correctness(h, d, num_blocks, k, error_mode=error_mode)
         print(f"{description:<20} {num_blocks:<8} {k:<4} "
               f"{results['gQ']['avg_diff']:<12.6e} {results['gQ']['max_diff']:<12.6e} "
               f"{results['gK']['avg_diff']:<12.6e} {results['gK']['max_diff']:<12.6e} "
@@ -150,10 +149,8 @@ def generate_error_graphs(h, d, mean, std, error_mode='all'):

 if __name__ == "__main__":
     h, d = 16, 128
-    mean = 0.0
-    std = 1
     print("Block Sparse Attention with Variable Block Sizes Analysis")
     print("=" * 60)
     for mode in ['backward']:
-        generate_error_graphs(h, d, mean, std, error_mode=mode)
+        generate_error_graphs(h, d, error_mode=mode)
     print("\nAnalysis completed for all modes.")