Skip to content

Commit 74a5f34

Browse files
authored
Remove function value error in version converter (#2791)
Fix #2790 This pull request makes a targeted change to the version converter in `onnxscript`. The main update removes the restriction that prevented models containing functions from being processed by the version conversion pass. Version conversion support update: * Removed the check that raised an error when the input model contained functions, allowing the version conversion pass to process such models without requiring prior inlining. (`onnxscript/version_converter/__init__.py`)
1 parent 0206a98 commit 74a5f34

File tree

3 files changed

+106
-37
lines changed

3 files changed

+106
-37
lines changed

onnxscript/version_converter/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
3838
self.target_version = target_version
3939
self.fallback = fallback
4040
self.convert_pass = ir.passes.Sequential(
41-
common_passes.InlinePass(),
42-
_ConvertVersionPassRequiresInline(
41+
_ConvertVersionPass(
4342
target_version=target_version,
4443
fallback=fallback,
4544
),
@@ -52,7 +51,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
5251
return self.convert_pass(model)
5352

5453

55-
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
54+
class _ConvertVersionPass(ir.passes.InPlacePass):
5655
"""Convert the model to the specified ONNX opset version.
5756
5857
This pass leverages the onnxscript version converter to convert the model. If
@@ -73,12 +72,6 @@ def __init__(self, target_version: int, fallback: bool) -> None:
7372
self.fallback = fallback
7473

7574
def call(self, model: ir.Model) -> ir.passes.PassResult:
76-
if model.functions:
77-
raise ValueError(
78-
"The model contains functions. The version conversion pass does not support "
79-
"functions. Please use `common_passes.InlinePass` to inline the "
80-
f"functions before applying this pass ({self.__class__.__name__})."
81-
)
8275
if "" in model.graph.opset_imports:
8376
onnx_opset_version = model.graph.opset_imports[""]
8477
if onnx_opset_version == self.target_version:

onnxscript/version_converter/_version_converter.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,10 @@ def visit_attribute(self, attr: ir.Attr) -> None:
274274
if attr.is_ref():
275275
return
276276
if attr.type == ir.AttributeType.GRAPH:
277-
self.visit_graph(attr.as_graph())
277+
self.visit_graph_or_function(attr.as_graph())
278278
elif attr.type == ir.AttributeType.GRAPHS:
279279
for graph in attr.as_graphs():
280-
self.visit_graph(graph)
280+
self.visit_graph_or_function(graph)
281281

282282
def visit_node(
283283
self,
@@ -303,8 +303,8 @@ def visit_node(
303303
self._default_metadata_merger.copy_merged_metadata([node], replacement.new_nodes)
304304
self.replace_node(node, replacement, root)
305305

306-
def visit_graph(self, graph: ir.Graph) -> None:
307-
for node in graph:
306+
def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) -> None:
307+
for node in graph_or_function:
308308
if node.domain != "":
309309
continue
310310
node_version = node.version or self._default_onnx_opset
@@ -321,7 +321,7 @@ def visit_graph(self, graph: ir.Graph) -> None:
321321
)
322322
for from_version in range(node_version, self._target_version):
323323
try:
324-
self.visit_node(node, graph, from_version, up_conversion=True)
324+
self.visit_node(node, graph_or_function, from_version, up_conversion=True)
325325
except VersionConverterError as e:
326326
logger.warning(
327327
"Skipping version conversion for node %s due to exception: %s",
@@ -331,7 +331,9 @@ def visit_graph(self, graph: ir.Graph) -> None:
331331

332332
def visit_model(self, model: ir.Model) -> None:
333333
self._default_onnx_opset = _get_onnx_opset_version(model)
334-
self.visit_graph(model.graph)
334+
self.visit_graph_or_function(model.graph)
335+
for function in model.functions.values():
336+
self.visit_graph_or_function(function)
335337
_set_onnx_opset_version(model, self._target_version)
336338

337339

onnxscript/version_converter/_version_converter_test.py

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -208,40 +208,114 @@ def test_version_convert_gridsample_cubic(self):
208208
self.assertEqual(model.graph.node(4).version, 20)
209209
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")
210210

211-
def test_version_convert_inline(self):
211+
def test_version_convert_function_nodes(self):
212+
"""Test that version converter processes nodes inside model functions."""
212213
model = ir.from_onnx_text(
213214
"""
214-
<ir_version: 8, opset_import: [ "" : 18]>
215-
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
215+
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
216+
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
216217
{
217-
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
218-
reshape_x = Reshape (input_x, shape_a)
219-
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
220-
reshape_y = Reshape (input_x, shape_b)
221-
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
222-
output = foo(gridsample)
218+
output = pkg.custom.dft_func (input_x)
223219
}
224220
225-
<opset_import: [ "" : 18]>
226-
foo (x) => (dft) {
227-
dft = DFT <axis = 2, onesided = 1> (x)
221+
<domain: "pkg.custom", opset_import: [ "" : 18]>
222+
dft_func (x) => (result) {
223+
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
224+
reshape_x = Reshape (x, shape_a)
225+
dft = DFT <axis = 2, onesided = 1> (reshape_x)
226+
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
227+
result = Reshape (dft, shape_c)
228228
}
229229
"""
230230
)
231+
# Verify the function exists with correct initial state
232+
self.assertEqual(len(model.functions), 1)
233+
func = model.functions[("pkg.custom", "dft_func", "")]
234+
self.assertEqual(len(func), 5) # 5 nodes in the function
235+
231236
target_version = 20
232237
version_converter.convert_version(model, target_version=target_version)
233238
self.assertEqual(model.opset_imports[""], target_version)
234239

235-
self.assertEqual(model.graph.node(0).op_type, "Constant")
236-
self.assertEqual(model.graph.node(0).version, 20)
237-
self.assertEqual(model.graph.node(1).op_type, "Reshape")
238-
self.assertEqual(model.graph.node(1).version, 20)
239-
self.assertEqual(model.graph.node(4).op_type, "GridSample")
240-
self.assertEqual(model.graph.node(4).version, 20)
241-
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
242-
self.assertEqual(model.graph.node(6).op_type, "DFT")
243-
self.assertEqual(model.graph.node(6).version, 20)
244-
self.assertEqual(len(model.graph.node(6).inputs), 3)
240+
# Verify that nodes inside the function were version-converted
241+
func = model.functions[("pkg.custom", "dft_func", "")]
242+
self.assertEqual(func[0].op_type, "Constant")
243+
self.assertEqual(func[0].version, 20)
244+
self.assertEqual(func[1].op_type, "Reshape")
245+
self.assertEqual(func[1].version, 20)
246+
# After DFT adapter, a new Constant node is inserted for dft_length
247+
self.assertEqual(func[2].op_type, "Constant")
248+
self.assertEqual(func[2].version, 20)
249+
self.assertEqual(func[3].op_type, "DFT")
250+
self.assertEqual(func[3].version, 20)
251+
self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input
252+
253+
def test_version_convert_function_with_control_flow_subgraph(self):
254+
"""Test that version converter processes subgraphs inside control flow nodes in functions."""
255+
model = ir.from_onnx_text(
256+
"""
257+
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
258+
agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output)
259+
{
260+
output = pkg.custom.conditional_dft (input_x, cond)
261+
}
262+
263+
<domain: "pkg.custom", opset_import: [ "" : 18]>
264+
conditional_dft (x, cond) => (result) {
265+
result = If (cond) <then_branch: graph = then_graph () => (out) {
266+
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
267+
reshape_x = Reshape (x, shape_a)
268+
dft = DFT <axis = 2, onesided = 1> (reshape_x)
269+
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
270+
out = Reshape (dft, shape_c)
271+
}, else_branch: graph = else_graph () => (out) {
272+
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
273+
out = Reshape (x, shape_c)
274+
}>
275+
}
276+
"""
277+
)
278+
# Verify the function exists with correct initial state
279+
self.assertEqual(len(model.functions), 1)
280+
func = model.functions[("pkg.custom", "conditional_dft", "")]
281+
self.assertEqual(len(func), 1) # 1 node (If) in the function
282+
283+
# Verify the If node has subgraphs
284+
if_node = func[0]
285+
self.assertEqual(if_node.op_type, "If")
286+
then_branch = if_node.attributes["then_branch"].as_graph()
287+
else_branch = if_node.attributes["else_branch"].as_graph()
288+
self.assertEqual(len(then_branch), 5) # 5 nodes in then_branch
289+
self.assertEqual(len(else_branch), 2) # 2 nodes in else_branch
290+
291+
target_version = 20
292+
# Use internal API to test function version conversion without inlining
293+
version_converter.convert_version(model, target_version=target_version)
294+
self.assertEqual(model.opset_imports[""], target_version)
295+
296+
# Verify nodes inside the function's If node subgraphs were version-converted
297+
func = model.functions[("pkg.custom", "conditional_dft", "")]
298+
if_node = func[0]
299+
self.assertEqual(if_node.op_type, "If")
300+
self.assertEqual(if_node.version, 20)
301+
302+
# Check then_branch subgraph nodes
303+
then_branch = if_node.attributes["then_branch"].as_graph()
304+
# After DFT adapter, a new Constant node is inserted for dft_length
305+
self.assertEqual(len(then_branch), 6) # 5 + 1 new Constant for DFT
306+
dft_node = None
307+
for node in then_branch:
308+
self.assertEqual(node.version, 20)
309+
if node.op_type == "DFT":
310+
dft_node = node
311+
self.assertIsNotNone(dft_node)
312+
self.assertEqual(len(dft_node.inputs), 3) # DFT 19->20 adds dft_length input
313+
314+
# Check else_branch subgraph nodes
315+
else_branch = if_node.attributes["else_branch"].as_graph()
316+
self.assertEqual(len(else_branch), 2)
317+
for node in else_branch:
318+
self.assertEqual(node.version, 20)
245319

246320

247321
class VersionConverter20to21Test(unittest.TestCase):

0 commit comments

Comments
 (0)