@@ -1145,3 +1145,114 @@ def test_quantized_relu(
11451145            torch .equal (output , expected_output ),
11461146            f"Output values don't match expected in { name } { output } { expected_output }  ,
11471147        )
1148+ 
1149+     @expand ( 
1150+         [ 
1151+             ( 
1152+                 "basic_2d" , 
1153+                 torch .tensor ( 
1154+                     [[1.0 , 2.0 , 3.0 , 4.0 ]], dtype = torch .float32  
1155+                 ),  # input: [1, 4]  
1156+                 torch .tensor ( 
1157+                     [[0.0 , 0.0 ]], dtype = torch .float32  
1158+                 ),  # sin: [1, 2] - broadcasts to [1, 4]  
1159+                 torch .tensor ( 
1160+                     [[1.0 , 1.0 ]], dtype = torch .float32  
1161+                 ),  # cos: [1, 2] - broadcasts to [1, 4]  
1162+                 torch .tensor ( 
1163+                     [[1.0 , 3.0 , 2.0 , 4.0 ]], dtype = torch .float32  
1164+                 ),  # expected: [1, 3, 2, 4]  
1165+             ), 
1166+             ( 
1167+                 "batch_sequence_3d" , 
1168+                 torch .tensor ( 
1169+                     [[[1.0 , 0.0 , 2.0 , 0.0 ]]], dtype = torch .float32  
1170+                 ),  # input: [1, 1, 4]  
1171+                 torch .tensor ([[[0.5 , 0.5 ]]], dtype = torch .float32 ),  # sin: [1, 1, 2]  
1172+                 torch .tensor ( 
1173+                     [[[0.866 , 0.866 ]]], dtype = torch .float32  
1174+                 ),  # cos: [1, 1, 2] (approx cos(30°))  
1175+                 torch .tensor ( 
1176+                     [[[0.866 , 1.732 , 0.5 , 1.0 ]]], dtype = torch .float32  
1177+                 ),  # expected: [0.866, 1.732, 0.5, 1.0]  
1178+             ), 
1179+             ( 
1180+                 "multiple_batch" , 
1181+                 torch .tensor ( 
1182+                     [[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = torch .float32  
1183+                 ),  # input: [2, 2]  
1184+                 torch .tensor ( 
1185+                     [[0.0 ], [1.0 ]], dtype = torch .float32  
1186+                 ),  # sin: [2, 1] - broadcasts to [2, 2]  
1187+                 torch .tensor ( 
1188+                     [[1.0 ], [0.0 ]], dtype = torch .float32  
1189+                 ),  # cos: [2, 1] - broadcasts to [2, 2]  
1190+                 torch .tensor ( 
1191+                     [[1.0 , 2.0 ], [- 4.0 , 3.0 ]], dtype = torch .float32  
1192+                 ),  # expected: [[1, 2], [-4, 3]]  
1193+             ), 
1194+             ( 
1195+                 "larger_embedding" , 
1196+                 torch .tensor ( 
1197+                     [[1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]], dtype = torch .float32  
1198+                 ),  # input: [1, 6]  
1199+                 torch .tensor ([[0.0 , 0.5 , 1.0 ]], dtype = torch .float32 ),  # sin: [1, 3]  
1200+                 torch .tensor ([[1.0 , 0.866 , 0.0 ]], dtype = torch .float32 ),  # cos: [1, 3]  
1201+                 torch .tensor ( 
1202+                     [[1.0 , 0.598 , - 6.0 , 2.0 , 4.964 , 5.0 ]], dtype = torch .float32  
1203+                 ),  # expected: [1, 0.598, -6, 2, 4.964, 5]  
1204+             ), 
1205+             ( 
1206+                 "single_pair" , 
1207+                 torch .tensor ([[1.0 , 2.0 ]], dtype = torch .float32 ),  # input: [1, 2]  
1208+                 torch .tensor ([[0.707 ]], dtype = torch .float32 ),  # sin: [1, 1] (sin(45°))  
1209+                 torch .tensor ([[0.707 ]], dtype = torch .float32 ),  # cos: [1, 1] (cos(45°))  
1210+                 torch .tensor ( 
1211+                     [[- 0.707 , 2.121 ]], dtype = torch .float32  
1212+                 ),  # expected: [-0.707, 2.121]  
1213+             ), 
1214+             ( 
1215+                 "pos is not None" , 
1216+                 torch .tensor (0 ), 
1217+                 torch .tensor (0 ), 
1218+                 torch .tensor (0 ), 
1219+                 torch .tensor (0 ), 
1220+                 torch .tensor (0 ),  # pos is not None  
1221+             ), 
1222+         ] 
1223+     ) 
1224+     def  test_rope (
1225+         self ,
1226+         name : str ,
1227+         input_tensor : torch .Tensor ,
1228+         sin_tensor : torch .Tensor ,
1229+         cos_tensor : torch .Tensor ,
1230+         expected_output : torch .Tensor ,
1231+         pos : torch .Tensor  |  None  =  None ,
1232+     ) ->  None :
1233+         if  pos  is  not None :
1234+             with  self .assertRaises (ValueError ) as  context :
1235+                 torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , pos )
1236+ 
1237+             self .assertIn ("pos is not supported" , str (context .exception ))
1238+             return 
1239+ 
1240+         output  =  torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , None )
1241+ 
1242+         # Verify output properties 
1243+         self .assertEqual (
1244+             output .dtype ,
1245+             input_tensor .dtype ,
1246+             f"Output dtype should match input dtype in { name }  ,
1247+         )
1248+         self .assertEqual (
1249+             output .shape ,
1250+             input_tensor .shape ,
1251+             f"Output shape should match input shape in { name }  ,
1252+         )
1253+ 
1254+         # Verify output matches expected values 
1255+         self .assertTrue (
1256+             torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1257+             f"Output values don't match expected in { name } { output } { expected_output }  ,
1258+         )
0 commit comments