@@ -392,17 +392,50 @@ def setUp(self):
392392 self .max_seq_len = 2048
393393 self .setup_caches ()
394394
395+ def _scale_tensor (self , tensor , min_value , max_value , scale = True ):
396+ normalized_tensor = (tensor - tensor .min ()) / (tensor .max () - tensor .min ())
397+
398+ scaled_tensor = normalized_tensor * (max_value - min_value ) + min_value
399+
400+ return scaled_tensor if scale else tensor
401+
395402 def _test_sdpa_common (
396- self , n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , next_iter_seq_len = 1
403+ self ,
404+ n_heads_kv ,
405+ n_heads_q ,
406+ head_dim ,
407+ max_seq_len ,
408+ seq_len ,
409+ next_iter_seq_len = 1 ,
410+ scale_tensors = False ,
397411 ):
412+ # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
413+ tensor_scale_max = 20
414+ tensor_scale_min = - 20
398415 self .n_heads_kv = n_heads_kv
399416 self .n_heads_q = n_heads_q
400417 self .head_dim = head_dim
401418 self .max_seq_len = max_seq_len
402419 self .setup_caches ()
403- q = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
404- k = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
405- v = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
# Scale q/k/v into [tensor_scale_min, tensor_scale_max].
# BUGFIX: _scale_tensor's signature is (tensor, min_value, max_value, scale);
# the original passed tensor_scale_max first, transposing min and max.
q = self._scale_tensor(
    torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
k = self._scale_tensor(
    torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
v = self._scale_tensor(
    torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
438+
406439 start_pos = 0
407440 attn_mask = self .mask [start_pos : start_pos + seq_len , :]
408441 attn_mask = attn_mask [:, : start_pos + seq_len ]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
412445 op_output = torch .ops .llama .sdpa_with_kv_cache (
413446 q , k , v , self .k_cache , self .v_cache , start_pos , seq_len , None , 0 , True
414447 )
415- self .assertTrue (torch .allclose (ref_output , op_output ))
448+ self .assertTrue (torch .allclose (ref_output , op_output , atol = 1e-6 ))
449+
# Fresh q/k/v for the second (decode) iteration, scaled into
# [tensor_scale_min, tensor_scale_max].
# BUGFIX: _scale_tensor's signature is (tensor, min_value, max_value, scale);
# the original passed tensor_scale_max first, transposing min and max.
q = self._scale_tensor(
    torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
k = self._scale_tensor(
    torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
v = self._scale_tensor(
    torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
    tensor_scale_min,
    tensor_scale_max,
    scale_tensors,
)
416468
417- q = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
418- k = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
419- v = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
420469 start_pos = seq_len
421470 seq_len = q .size (1 )
422471 attn_mask = self .mask [start_pos : start_pos + seq_len , :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
427476 op_output = torch .ops .llama .sdpa_with_kv_cache (
428477 q , k , v , self .k_cache , self .v_cache , start_pos , seq_len , None , 0 , True
429478 )
430- self .assertTrue (torch .allclose (ref_output , op_output ))
479+ self .assertTrue (torch .allclose (ref_output , op_output , atol = 1e-6 ))
431480
432481
433482class SDPATestForLargeSeqLength (SDPATestCommon ):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
438487 head_dim = 128
439488 max_seq_len = 2048
440489 seq_len = 130
441- self ._test_sdpa_common (n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len )
# BUGFIX: the original passed True as the 6th positional argument, which
# bound it to next_iter_seq_len (True == 1, same as the default) and left
# scale_tensors at False — the scaling was never actually enabled.
self._test_sdpa_common(
    n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
)
442493
443494 def test_sdpa_with_cache_seq_len_small (self ):
444495 n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
462513 head_dim = 128
463514 max_seq_len = 2048
464515 seq_len = 130
465- self ._test_sdpa_common (n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len )
# BUGFIX: the original passed True as the 6th positional argument, which
# bound it to next_iter_seq_len (True == 1, same as the default) and left
# scale_tensors at False — the scaling was never actually enabled.
self._test_sdpa_common(
    n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
)
466519
467520 def test_sdpa_with_cache_seq_len_llava_example_gqa (self ):
468521 n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
483536 seq_len = 130
484537 next_iter_seq_len = 17
# Run the common SDPA check with a second decode iteration of
# next_iter_seq_len tokens and input scaling enabled. The trailing
# configuration arguments are passed by keyword for clarity.
self._test_sdpa_common(
    n_heads_kv,
    n_heads_q,
    head_dim,
    max_seq_len,
    seq_len,
    next_iter_seq_len=next_iter_seq_len,
    scale_tensors=True,
)
488547
489548 def test_sdpa_with_cache_seq_len_llava_example (self ):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
505564 seq_len = 130
506565 next_iter_seq_len = 33
# Run the common SDPA check (GQA variant) with a second decode iteration
# of next_iter_seq_len tokens and input scaling enabled. The trailing
# configuration arguments are passed by keyword for clarity.
self._test_sdpa_common(
    n_heads_kv,
    n_heads_q,
    head_dim,
    max_seq_len,
    seq_len,
    next_iter_seq_len=next_iter_seq_len,
    scale_tensors=True,
)
510575
511576 def test_sdpa_with_cache_seq_len_llava_example_gqa (self ):
0 commit comments