@@ -1156,3 +1156,103 @@ def test_where_Scalar(self) -> None:
11561156 torch .ops .cadence .where_Scalar (input_tensor , 1.0 , 0.0 )
11571157
11581158 self .assertIn ("condition must be a bool tensor" , str (context .exception ))
1159+
1160+ @expand (
1161+ [
1162+ (
1163+ "h1xhd4" ,
1164+ torch .tensor ([[[[1.0 , 2.0 , 3.0 , 4.0 ]]]], dtype = torch .float32 ),
1165+ torch .tensor ([[0.0 , 0.0 ]], dtype = torch .float32 ),
1166+ torch .tensor ([[1.0 , 1.0 ]], dtype = torch .float32 ),
1167+ torch .tensor ([[[[1.0 , 3.0 , 2.0 , 4.0 ]]]], dtype = torch .float32 ),
1168+ ),
1169+ (
1170+ "h2xhd4" ,
1171+ torch .tensor (
1172+ [[[[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]]]],
1173+ dtype = torch .float32 ,
1174+ ),
1175+ torch .tensor ([[0.0 , 1.0 ]], dtype = torch .float32 ),
1176+ torch .tensor ([[1.0 , 0.0 ]], dtype = torch .float32 ),
1177+ torch .tensor (
1178+ [[[[1.0 , - 4.0 , 2.0 , 3.0 ], [5 , - 8.0 , 6.0 , 7.0 ]]]],
1179+ dtype = torch .float32 ,
1180+ ),
1181+ ),
1182+ (
1183+ "s2xh2xhd4" ,
1184+ torch .tensor (
1185+ [
1186+ [
1187+ [[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]],
1188+ [[9.0 , 10.0 , 11.0 , 12.0 ], [13.0 , 14.0 , 15.0 , 16.0 ]],
1189+ ]
1190+ ],
1191+ dtype = torch .float32 ,
1192+ ),
1193+ torch .tensor ([[0.0 , 1.0 ], [0.0 , 1.0 ]], dtype = torch .float32 ),
1194+ torch .tensor ([[1.0 , 0.0 ], [1.0 , 0.0 ]], dtype = torch .float32 ),
1195+ torch .tensor (
1196+ [
1197+ [
1198+ [[1.0 , - 4.0 , 2.0 , 3.0 ], [5.0 , - 8.0 , 6.0 , 7.0 ]],
1199+ [[9.0 , - 12.0 , 10.0 , 11.0 ], [13.0 , - 16.0 , 14.0 , 15.0 ]],
1200+ ]
1201+ ],
1202+ dtype = torch .float32 ,
1203+ ),
1204+ ),
1205+ (
1206+ "pos_not_none" ,
1207+ torch .tensor (
1208+ [
1209+ [
1210+ [[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]],
1211+ [[9.0 , 10.0 , 11.0 , 12.0 ], [13.0 , 14.0 , 15.0 , 16.0 ]],
1212+ ]
1213+ ],
1214+ dtype = torch .float32 ,
1215+ ),
1216+ torch .tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]], dtype = torch .float32 ),
1217+ torch .tensor ([[0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = torch .float32 ),
1218+ torch .tensor (
1219+ [
1220+ [
1221+ [[1.0 , - 4.0 , 2.0 , 3.0 ], [5.0 , - 8.0 , 6.0 , 7.0 ]],
1222+ [[- 10.0 , 11.0 , 9.0 , 12.0 ], [- 14.0 , 15.0 , 13.0 , 16.0 ]],
1223+ ]
1224+ ],
1225+ dtype = torch .float32 ,
1226+ ),
1227+ torch .tensor ([1 , 0 ]),
1228+ ),
1229+ ]
1230+ )
1231+ def test_rope (
1232+ self ,
1233+ name : str ,
1234+ input_tensor : torch .Tensor ,
1235+ sin_tensor : torch .Tensor ,
1236+ cos_tensor : torch .Tensor ,
1237+ expected_output : torch .Tensor ,
1238+ pos : torch .Tensor | None = None ,
1239+ ) -> None :
1240+ output = torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , pos )
1241+
1242+ # Verify output properties
1243+ self .assertEqual (
1244+ output .dtype ,
1245+ input_tensor .dtype ,
1246+ f"Output dtype should match input dtype in { name } " ,
1247+ )
1248+ self .assertEqual (
1249+ output .shape ,
1250+ input_tensor .shape ,
1251+ f"Output shape should match input shape in { name } " ,
1252+ )
1253+
1254+ # Verify output matches expected values
1255+ self .assertTrue (
1256+ torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1257+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1258+ )
0 commit comments