@@ -1201,11 +1201,30 @@ def test_forward_scriptability(self):
         torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
 
 
+# NS: Remove me once backward is implemented for MPS
+def xfail_if_mps(x):
+    mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
+    new_pytestmark = []
+    for mark in x.pytestmark:
+        if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
+            if mark.args[0] == "device":
+                params = cpu_and_cuda() + (mps_xfail_param,)
+                new_pytestmark.append(pytest.mark.parametrize("device", params))
+                continue
+        new_pytestmark.append(mark)
+    x.__dict__["pytestmark"] = new_pytestmark
+    return x
+
+
 optests.generate_opcheck_tests(
     testcase=TestDeformConv,
     namespaces=["torchvision"],
     failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
-    additional_decorators=[],
+    # Skip tests due to unimplemented backward
+    additional_decorators={
+        "test_aot_dispatch_dynamic__test_forward": [xfail_if_mps],
+        "test_autograd_registration__test_forward": [xfail_if_mps],
+    },
     test_utils=OPTESTS,
 )
 
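For context, a minimal standalone sketch (not part of the diff) of how a decorator in the style of xfail_if_mps rewrites a test's "device" parametrization so that an extra "mps" case is collected but expected to fail. The cpu_and_cuda helper and the needs_cuda/needs_mps marks below are simplified stand-ins for the ones the test suite defines elsewhere; in the change above the decorator is not applied directly but handed to optests.generate_opcheck_tests via additional_decorators, which applies it only to the named generated tests.

import pytest


def cpu_and_cuda():
    # Simplified stand-in for the suite's helper; the real one also attaches a needs_cuda mark.
    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def xfail_if_mps(test_fn):
    # Rebuild the test's mark list, extending the "device" parametrization with an xfail-ing "mps" entry.
    mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
    new_pytestmark = []
    for mark in test_fn.pytestmark:
        if isinstance(mark, pytest.Mark) and mark.name == "parametrize" and mark.args[0] == "device":
            new_pytestmark.append(pytest.mark.parametrize("device", cpu_and_cuda() + (mps_xfail_param,)))
            continue
        new_pytestmark.append(mark)
    test_fn.pytestmark = new_pytestmark
    return test_fn


# Decorators apply bottom-up: parametrize populates pytestmark first, then xfail_if_mps rewrites it,
# so this hypothetical test collects cpu, cuda, and an mps case marked xfail.
@xfail_if_mps
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_forward(device):
    assert device in ("cpu", "cuda", "mps")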