Skip to content

Commit d83d50b

Browse files
titaiwangmsCopilotjustinchuby
authored
Raise version converter error when function attribute is RefAttr (#2806)
Regarding the concern from #2791 (comment), this PR specifically raises when function has RefAttr. --- This pull request adds a new validation to the version converter to ensure nodes with reference attributes (`RefAttr`) are not supported, and introduces a corresponding unit test to verify that the version converter raises an error in such cases. Validation for unsupported reference attributes: * Added a check in `visit_graph_or_function` in `onnxscript/version_converter/_version_converter.py` to raise a `VersionConverterError` if any node has a reference attribute (`RefAttr`), as these are not currently supported by the version converter. Testing and error handling: * Added a new unit test `test_version_convert_raises_on_function_node_with_ref_attribute` in `onnxscript/version_converter/_version_converter_test.py` to ensure that converting a model containing a function node with a reference attribute raises the appropriate error and includes a clear message. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 2b2618e commit d83d50b

File tree

3 files changed

+76
-6
lines changed

3 files changed

+76
-6
lines changed

onnxscript/version_converter/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,26 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
3737
super().__init__()
3838
self.target_version = target_version
3939
self.fallback = fallback
40-
self.convert_pass = ir.passes.Sequential(
41-
_ConvertVersionPass(
42-
target_version=target_version,
43-
fallback=fallback,
44-
),
40+
self._convert_pass = _ConvertVersionPass(
41+
target_version=target_version,
42+
fallback=fallback,
43+
)
44+
self._cleanup_passes = ir.passes.Sequential(
4545
common_passes.RemoveUnusedNodesPass(),
4646
common_passes.RemoveUnusedFunctionsPass(),
4747
common_passes.RemoveUnusedOpsetsPass(),
4848
)
4949

5050
def call(self, model: ir.Model) -> ir.passes.PassResult:
51-
return self.convert_pass(model)
51+
# Run the conversion pass outside of Sequential so that errors
52+
# (e.g. VersionConverterError) propagate directly without being
53+
# wrapped in PassError.
54+
result = self._convert_pass(model)
55+
cleanup_result = self._cleanup_passes(result)
56+
return ir.passes.PassResult(
57+
cleanup_result.model,
58+
result.modified or cleanup_result.modified,
59+
)
5260

5361

5462
class _ConvertVersionPass(ir.passes.InPlacePass):

onnxscript/version_converter/_version_converter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,11 @@ def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) ->
310310
node_version = node.version or self._default_onnx_opset
311311
if node_version is None:
312312
raise VersionConverterError(f"Node {node} has no version.")
313+
# RefAttr is not supported by adapters for now.
314+
if any(attr.is_ref() for attr in node.attributes.values()):
315+
raise VersionConverterError(
316+
f"Node '{node!r}' has ref attribute, which is not supported by version converter."
317+
)
313318
# Iterate each node from current node version -> target version
314319
# and updating node based on the correct adapter
315320
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]

onnxscript/version_converter/_version_converter_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,63 @@ def test_metadata_is_copied_to_multiple_replacement_nodes(self):
455455
f"Node {i} ({node.op_type}) should have metadata copied",
456456
)
457457

458+
def test_version_convert_raises_on_function_node_with_ref_attribute(self):
459+
"""Test that version conversion raises when a function contains a node with a ref attribute."""
460+
# Build a function with a LeakyRelu node that uses a RefAttr for 'alpha'
461+
func_input = ir.Value(name="x")
462+
ref_attr = ir.RefAttr("alpha", "alpha", ir.AttributeType.FLOAT)
463+
func_output = ir.Value(name="result")
464+
leaky_relu_node = ir.Node(
465+
domain="",
466+
op_type="LeakyRelu",
467+
inputs=[func_input],
468+
outputs=[func_output],
469+
attributes=[ref_attr],
470+
version=18,
471+
)
472+
func_graph = ir.Graph(
473+
inputs=[func_input],
474+
outputs=[func_output],
475+
nodes=[leaky_relu_node],
476+
opset_imports={"": 18},
477+
)
478+
func_attr_param = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.01)
479+
function = ir.Function(
480+
domain="pkg.custom",
481+
name="leaky_relu_func",
482+
graph=func_graph,
483+
attributes=[func_attr_param],
484+
)
485+
486+
# Build a main graph that calls the function
487+
main_input = ir.Value(name="input_x")
488+
main_output = ir.Value(name="output")
489+
call_node = ir.Node(
490+
domain="pkg.custom",
491+
op_type="leaky_relu_func",
492+
inputs=[main_input],
493+
outputs=[main_output],
494+
version=18,
495+
)
496+
main_graph = ir.Graph(
497+
inputs=[main_input],
498+
outputs=[main_output],
499+
nodes=[call_node],
500+
opset_imports={"": 18, "pkg.custom": 1},
501+
)
502+
model = ir.Model(
503+
main_graph,
504+
ir_version=8,
505+
functions=[function],
506+
)
507+
508+
target_version = 20
509+
with self.assertRaisesRegex(
510+
version_converter._version_converter.VersionConverterError, # pylint: disable=protected-access
511+
"has ref attribute, which is not supported by version converter",
512+
):
513+
version_converter.convert_version(model, target_version=target_version)
514+
458515

459516
class VersionConverter25to26Test(unittest.TestCase):
460517
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 25 not yet supported.")

0 commit comments

Comments
 (0)