Skip to content

Commit 5ce8ae9

Browse files
committed
move out version converter logic from sequential
1 parent a5b0d86 commit 5ce8ae9

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

onnxscript/version_converter/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,26 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
3737
super().__init__()
3838
self.target_version = target_version
3939
self.fallback = fallback
40-
self.convert_pass = ir.passes.Sequential(
41-
_ConvertVersionPass(
42-
target_version=target_version,
43-
fallback=fallback,
44-
),
40+
self._convert_pass = _ConvertVersionPass(
41+
target_version=target_version,
42+
fallback=fallback,
43+
)
44+
self._cleanup_passes = ir.passes.Sequential(
4545
common_passes.RemoveUnusedNodesPass(),
4646
common_passes.RemoveUnusedFunctionsPass(),
4747
common_passes.RemoveUnusedOpsetsPass(),
4848
)
4949

5050
def call(self, model: ir.Model) -> ir.passes.PassResult:
51-
return self.convert_pass(model)
51+
# Run the conversion pass outside of Sequential so that errors
52+
# (e.g. VersionConverterError) propagate directly without being
53+
# wrapped in PassError.
54+
result = self._convert_pass(model)
55+
cleanup_result = self._cleanup_passes(result)
56+
return ir.passes.PassResult(
57+
cleanup_result.model,
58+
result.modified or cleanup_result.modified,
59+
)
5260

5361

5462
class _ConvertVersionPass(ir.passes.InPlacePass):

onnxscript/version_converter/_version_converter_test.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -507,19 +507,12 @@ def test_version_convert_raises_on_function_node_with_ref_attribute(self):
507507

508508
target_version = 20
509509
with self.assertRaises(
510-
(
511-
version_converter._version_converter.VersionConverterError, # pylint: disable=protected-access
512-
ir.passes.PassError,
513-
)
510+
version_converter._version_converter.VersionConverterError, # pylint: disable=protected-access
514511
) as ctx:
515512
version_converter.convert_version(model, target_version=target_version)
516-
# Check the error message, unwrapping PassError if needed
517-
error = ctx.exception
518-
if isinstance(error, ir.passes.PassError) and error.__cause__ is not None:
519-
error = error.__cause__
520-
self.assertIn(
513+
self.assertRegex(
514+
str(ctx.exception),
521515
"has ref attribute, which is not supported by version converter",
522-
str(error),
523516
)
524517

525518

0 commit comments

Comments
 (0)