@@ -43,16 +43,23 @@ def test_mrope_triton_correctness(B, H_q, H_k, L, D, mrope_section):
4343
4444 torch .manual_seed (0 )
4545 device = "cuda"
46+ HALF = D // 2
4647
4748 q = torch .rand ((B , H_q , L , D ), dtype = torch .float32 , device = device )
4849 k = torch .rand ((B , H_k , L , D ), dtype = torch .float32 , device = device )
49- cos = torch .rand ((3 , 1 , L , D ), dtype = torch .float32 , device = device )
50- sin = torch .rand ((3 , 1 , L , D ), dtype = torch .float32 , device = device )
5150
52- ref_q , ref_k = apply_multimodal_rotary_pos_emb (q , k , cos , sin , mrope_section , unsqueeze_dim = 1 )
51+ cos_half = torch .rand ((3 , L , HALF ), dtype = torch .float32 , device = device )
52+ sin_half = torch .rand ((3 , L , HALF ), dtype = torch .float32 , device = device )
5353
54- out_q , out_k = mrope_triton (q , k , cos , sin , axis_map )
54+ cos_full = torch .cat ([cos_half , cos_half ], dim = - 1 )
55+ sin_full = torch .cat ([sin_half , sin_half ], dim = - 1 )
5556
57+ cos_ref = cos_full .unsqueeze (1 )
58+ sin_ref = sin_full .unsqueeze (1 )
59+
60+ ref_q , ref_k = apply_multimodal_rotary_pos_emb (q , k , cos_ref , sin_ref , mrope_section , unsqueeze_dim = 1 )
61+
62+ out_q , out_k = mrope_triton (q , k , cos_half , sin_half , axis_map )
5663 assert torch .allclose (out_q , ref_q , rtol = 1e-3 , atol = 1e-3 )
5764 assert torch .allclose (out_k , ref_k , rtol = 1e-3 , atol = 1e-3 )
5865
0 commit comments