Skip to content

Commit 3c3d83a

Browse files
committed
fix conv tolerances
1 parent 80ee8be commit 3c3d83a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,7 @@ def _where_input_wrangler(
15651565
"nn.functional.conv1d",
15661566
core_ops.aten_conv1d_complex,
15671567
complex=True,
1568+
tolerance={torch.complex64: (2e-5, 3e-5)},
15681569
).xfail(
15691570
matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
15701571
reason="String padding is not accepted by aten::conv1d",
@@ -1581,7 +1582,7 @@ def _where_input_wrangler(
15811582
"nn.functional.conv2d",
15821583
core_ops.aten_conv2d_complex,
15831584
complex=True,
1584-
tolerance={torch.float32: (2e-5, 3e-5)},
1585+
tolerance={torch.complex64: (2e-5, 3e-5)},
15851586
).xfail(
15861587
matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
15871588
reason="String padding is not accepted by aten::conv2d",
@@ -1595,7 +1596,7 @@ def _where_input_wrangler(
15951596
"ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)}
15961597
),
15971598
TorchLibOpInfo(
1598-
"ops.aten.conv3d", core_ops.aten_conv3d_complex, complex=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}
1599+
"ops.aten.conv3d", core_ops.aten_conv3d_complex, complex=True, tolerance={torch.complex64: (1e-4, 5e-4)}
15991600
),
16001601
TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu),
16011602
TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),

0 commit comments

Comments
 (0)