@@ -1565,6 +1565,7 @@ def _where_input_wrangler(
15651565 "nn.functional.conv1d" ,
15661566 core_ops .aten_conv1d_complex ,
15671567 complex = True ,
1568+ tolerance = {torch .complex64 : (2e-5 , 3e-5 )},
15681569 ).xfail (
15691570 matcher = lambda sample : isinstance (sample .kwargs .get ("padding" ), str ),
15701571 reason = "String padding is not accepted by aten::conv1d" ,
@@ -1581,7 +1582,7 @@ def _where_input_wrangler(
15811582 "nn.functional.conv2d" ,
15821583 core_ops .aten_conv2d_complex ,
15831584 complex = True ,
1584- tolerance = {torch .float32 : (2e-5 , 3e-5 )},
1585+ tolerance = {torch .complex64 : (2e-5 , 3e-5 )},
15851586 ).xfail (
15861587 matcher = lambda sample : isinstance (sample .kwargs .get ("padding" ), str ),
15871588 reason = "String padding is not accepted by aten::conv2d" ,
@@ -1595,7 +1596,7 @@ def _where_input_wrangler(
15951596 "ops.aten.conv3d" , core_ops .aten_conv3d , tolerance = {torch .float32 : (3.7e-5 , 1.8e-4 )}
15961597 ),
15971598 TorchLibOpInfo (
1598- "ops.aten.conv3d" , core_ops .aten_conv3d_complex , complex = True , tolerance = {torch .float32 : (3.7e-5 , 1.8e -4 )}
1599+ "ops.aten.conv3d" , core_ops .aten_conv3d_complex , complex = True , tolerance = {torch .complex64 : (1e-4 , 5e -4 )}
15991600 ),
16001601 TorchLibOpInfo ("nn.functional.gelu" , nn_ops .aten_gelu ),
16011602 TorchLibOpInfo ("nn.functional.glu" , nn_ops .aten_glu ),
0 commit comments