@@ -1145,3 +1145,114 @@ def test_quantized_relu(
11451145 torch .equal (output , expected_output ),
11461146 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
11471147 )
1148+
1149+ @expand (
1150+ [
1151+ (
1152+ "basic_2d" ,
1153+ torch .tensor (
1154+ [[1.0 , 2.0 , 3.0 , 4.0 ]], dtype = torch .float32
1155+ ), # input: [1, 4]
1156+ torch .tensor (
1157+ [[0.0 , 0.0 ]], dtype = torch .float32
1158+ ), # sin: [1, 2] - broadcasts to [1, 4]
1159+ torch .tensor (
1160+ [[1.0 , 1.0 ]], dtype = torch .float32
1161+ ), # cos: [1, 2] - broadcasts to [1, 4]
1162+ torch .tensor (
1163+ [[1.0 , 3.0 , 2.0 , 4.0 ]], dtype = torch .float32
1164+ ), # expected: [1, 3, 2, 4]
1165+ ),
1166+ (
1167+ "batch_sequence_3d" ,
1168+ torch .tensor (
1169+ [[[1.0 , 0.0 , 2.0 , 0.0 ]]], dtype = torch .float32
1170+ ), # input: [1, 1, 4]
1171+ torch .tensor ([[[0.5 , 0.5 ]]], dtype = torch .float32 ), # sin: [1, 1, 2]
1172+ torch .tensor (
1173+ [[[0.866 , 0.866 ]]], dtype = torch .float32
1174+ ), # cos: [1, 1, 2] (approx cos(30°))
1175+ torch .tensor (
1176+ [[[0.866 , 1.732 , 0.5 , 1.0 ]]], dtype = torch .float32
1177+ ), # expected: [0.866, 1.732, 0.5, 1.0]
1178+ ),
1179+ (
1180+ "multiple_batch" ,
1181+ torch .tensor (
1182+ [[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = torch .float32
1183+ ), # input: [2, 2]
1184+ torch .tensor (
1185+ [[0.0 ], [1.0 ]], dtype = torch .float32
1186+ ), # sin: [2, 1] - broadcasts to [2, 2]
1187+ torch .tensor (
1188+ [[1.0 ], [0.0 ]], dtype = torch .float32
1189+ ), # cos: [2, 1] - broadcasts to [2, 2]
1190+ torch .tensor (
1191+ [[1.0 , 2.0 ], [- 4.0 , 3.0 ]], dtype = torch .float32
1192+ ), # expected: [[1, 2], [-4, 3]]
1193+ ),
1194+ (
1195+ "larger_embedding" ,
1196+ torch .tensor (
1197+ [[1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]], dtype = torch .float32
1198+ ), # input: [1, 6]
1199+ torch .tensor ([[0.0 , 0.5 , 1.0 ]], dtype = torch .float32 ), # sin: [1, 3]
1200+ torch .tensor ([[1.0 , 0.866 , 0.0 ]], dtype = torch .float32 ), # cos: [1, 3]
1201+ torch .tensor (
1202+ [[1.0 , 0.598 , - 6.0 , 2.0 , 4.964 , 5.0 ]], dtype = torch .float32
1203+ ), # expected: [1, 0.598, -6, 2, 4.964, 5]
1204+ ),
1205+ (
1206+ "single_pair" ,
1207+ torch .tensor ([[1.0 , 2.0 ]], dtype = torch .float32 ), # input: [1, 2]
1208+ torch .tensor ([[0.707 ]], dtype = torch .float32 ), # sin: [1, 1] (sin(45°))
1209+ torch .tensor ([[0.707 ]], dtype = torch .float32 ), # cos: [1, 1] (cos(45°))
1210+ torch .tensor (
1211+ [[- 0.707 , 2.121 ]], dtype = torch .float32
1212+ ), # expected: [-0.707, 2.121]
1213+ ),
1214+ (
1215+ "pos is not None" ,
1216+ torch .tensor (0 ),
1217+ torch .tensor (0 ),
1218+ torch .tensor (0 ),
1219+ torch .tensor (0 ),
1220+ torch .tensor (0 ), # pos is not None
1221+ ),
1222+ ]
1223+ )
1224+ def test_rope (
1225+ self ,
1226+ name : str ,
1227+ input_tensor : torch .Tensor ,
1228+ sin_tensor : torch .Tensor ,
1229+ cos_tensor : torch .Tensor ,
1230+ expected_output : torch .Tensor ,
1231+ pos : torch .Tensor | None = None ,
1232+ ) -> None :
1233+ if pos is not None :
1234+ with self .assertRaises (ValueError ) as context :
1235+ torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , pos )
1236+
1237+ self .assertIn ("pos is not supported" , str (context .exception ))
1238+ return
1239+
1240+ output = torch .ops .cadence .rope (input_tensor , sin_tensor , cos_tensor , None )
1241+
1242+ # Verify output properties
1243+ self .assertEqual (
1244+ output .dtype ,
1245+ input_tensor .dtype ,
1246+ f"Output dtype should match input dtype in { name } " ,
1247+ )
1248+ self .assertEqual (
1249+ output .shape ,
1250+ input_tensor .shape ,
1251+ f"Output shape should match input shape in { name } " ,
1252+ )
1253+
1254+ # Verify output matches expected values
1255+ self .assertTrue (
1256+ torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1257+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1258+ )
0 commit comments