@@ -1156,3 +1156,99 @@ def test_where_Scalar(self) -> None:
11561156 torch .ops .cadence .where_Scalar (input_tensor , 1.0 , 0.0 )
11571157
11581158 self .assertIn ("condition must be a bool tensor" , str (context .exception ))
1159+
1160+ @expand (
1161+ [
1162+ (
1163+ "h1xhd4" ,
1164+ torch .tensor (
1165+ [[[[1.0 , 2.0 , 3.0 , 4.0 ]]]], dtype = torch .float32
1166+ ),
1167+ torch .tensor (
1168+ [[0.0 , 0.0 ]], dtype = torch .float32
1169+ ),
1170+ torch .tensor (
1171+ [[1.0 , 1.0 ]], dtype = torch .float32
1172+ ),
1173+ torch .tensor (
1174+ [[[[1.0 , 3.0 , 2.0 , 4.0 ]]]], dtype = torch .float32
1175+ ),
1176+ ),
1177+ (
1178+ "h2xhd4" ,
1179+ torch .tensor (
1180+ [[[[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]]]], dtype = torch .float32
1181+ ),
1182+ torch .tensor (
1183+ [[0.0 , 1.0 ]], dtype = torch .float32
1184+ ),
1185+ torch .tensor (
1186+ [[1.0 , 0.0 ]], dtype = torch .float32
1187+ ),
1188+ torch .tensor (
1189+ [[[[1.0 , - 4.0 , 2.0 , 3.0 ], [5 , - 8.0 , 6.0 , 7.0 ]]]], dtype = torch .float32
1190+ ),
1191+ ),
1192+ (
1193+ "s2xh2xhd4" ,
1194+ torch .tensor (
1195+ [[[[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]], [[9.0 , 10.0 , 11.0 , 12.0 ], [13.0 , 14.0 , 15.0 , 16.0 ]]]], dtype = torch .float32
1196+ ),
1197+ torch .tensor (
1198+ [[0.0 , 1.0 ], [0.0 , 1.0 ]], dtype = torch .float32
1199+ ),
1200+ torch .tensor (
1201+ [[1.0 , 0.0 ], [1.0 , 0.0 ]], dtype = torch .float32
1202+ ),
1203+ torch .tensor ([[[[ 1. , - 4. , 2. , 3. ],
1204+ [ 5. , - 8. , 6. , 7. ]],
1205+ [[ 9. , - 12. , 10. , 11. ],
1206+ [ 13. , - 16. , 14. , 15. ]]]], dtype = torch .float32 ),
1207+ ),
1208+ (
1209+ "pos_not_none" ,
1210+ torch .tensor (
1211+ [[[[1.0 , 2.0 , 3.0 , 4.0 ], [5.0 , 6.0 , 7.0 , 8.0 ]], [[9.0 , 10.0 , 11.0 , 12.0 ], [13.0 , 14.0 , 15.0 , 16.0 ]]]], dtype = torch .float32
1212+ ),
1213+ torch .tensor (
1214+ [[1.0 , 0.0 ], [0.0 , 1.0 ]], dtype = torch .float32
1215+ ),
1216+ torch .tensor (
1217+ [[0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = torch .float32
1218+ ),
1219+ torch .tensor ([[[[ 1. , - 4. , 2. , 3. ],
1220+ [ 5. , - 8. , 6. , 7. ]],
1221+ [[ - 10. , 11. , 9. , 12. ],
1222+ [ - 14. , 15. , 13. , 16. ]]]], dtype = torch .float32 ),
1223+ torch .tensor ([1 , 0 ])
1224+ ),
1225+ ]
1226+ )
1227+ def test_rope (
1228+ self ,
1229+ name : str ,
1230+ input_tensor : torch .Tensor ,
1231+ sin_tensor : torch .Tensor ,
1232+ cos_tensor : torch .Tensor ,
1233+ expected_output : torch .Tensor ,
1234+ pos : torch .Tensor | None = None ,
1235+ ) -> None :
1236+ output = torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , pos )
1237+
1238+ # Verify output properties
1239+ self .assertEqual (
1240+ output .dtype ,
1241+ input_tensor .dtype ,
1242+ f"Output dtype should match input dtype in { name } " ,
1243+ )
1244+ self .assertEqual (
1245+ output .shape ,
1246+ input_tensor .shape ,
1247+ f"Output shape should match input shape in { name } " ,
1248+ )
1249+
1250+ # Verify output matches expected values
1251+ self .assertTrue (
1252+ torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1253+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1254+ )
0 commit comments