@@ -208,40 +208,114 @@ def test_version_convert_gridsample_cubic(self):
208208 self .assertEqual (model .graph .node (4 ).version , 20 )
209209 self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "cubic" )
210210
211- def test_version_convert_inline (self ):
211+ def test_version_convert_function_nodes (self ):
212+ """Test that version converter processes nodes inside model functions."""
212213 model = ir .from_onnx_text (
213214 """
214- <ir_version: 8, opset_import: [ "" : 18]>
215- agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y ) => (float[4, 257, 64, 2] output)
215+ <ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1 ]>
216+ agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
216217 {
217- shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
218- reshape_x = Reshape (input_x, shape_a)
219- shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
220- reshape_y = Reshape (input_x, shape_b)
221- gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
222- output = foo(gridsample)
218+ output = pkg.custom.dft_func (input_x)
223219 }
224220
225- <opset_import: [ "" : 18]>
226- foo (x) => (dft) {
227- dft = DFT <axis = 2, onesided = 1> (x)
221+ <domain: "pkg.custom", opset_import: [ "" : 18]>
222+ dft_func (x) => (result) {
223+ shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
224+ reshape_x = Reshape (x, shape_a)
225+ dft = DFT <axis = 2, onesided = 1> (reshape_x)
226+ shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
227+ result = Reshape (dft, shape_c)
228228 }
229229 """
230230 )
231+ # Verify the function exists with correct initial state
232+ self .assertEqual (len (model .functions ), 1 )
233+ func = model .functions [("pkg.custom" , "dft_func" , "" )]
234+ self .assertEqual (len (func ), 5 ) # 5 nodes in the function
235+
231236 target_version = 20
232237 version_converter .convert_version (model , target_version = target_version )
233238 self .assertEqual (model .opset_imports ["" ], target_version )
234239
235- self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
236- self .assertEqual (model .graph .node (0 ).version , 20 )
237- self .assertEqual (model .graph .node (1 ).op_type , "Reshape" )
238- self .assertEqual (model .graph .node (1 ).version , 20 )
239- self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
240- self .assertEqual (model .graph .node (4 ).version , 20 )
241- self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "linear" )
242- self .assertEqual (model .graph .node (6 ).op_type , "DFT" )
243- self .assertEqual (model .graph .node (6 ).version , 20 )
244- self .assertEqual (len (model .graph .node (6 ).inputs ), 3 )
240+ # Verify that nodes inside the function were version-converted
241+ func = model .functions [("pkg.custom" , "dft_func" , "" )]
242+ self .assertEqual (func [0 ].op_type , "Constant" )
243+ self .assertEqual (func [0 ].version , 20 )
244+ self .assertEqual (func [1 ].op_type , "Reshape" )
245+ self .assertEqual (func [1 ].version , 20 )
246+ # After DFT adapter, a new Constant node is inserted for dft_length
247+ self .assertEqual (func [2 ].op_type , "Constant" )
248+ self .assertEqual (func [2 ].version , 20 )
249+ self .assertEqual (func [3 ].op_type , "DFT" )
250+ self .assertEqual (func [3 ].version , 20 )
251+ self .assertEqual (len (func [3 ].inputs ), 3 ) # DFT 19->20 adds dft_length input
252+
253+ def test_version_convert_function_with_control_flow_subgraph (self ):
254+ """Test that version converter processes subgraphs inside control flow nodes in functions."""
255+ model = ir .from_onnx_text (
256+ """
257+ <ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
258+ agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output)
259+ {
260+ output = pkg.custom.conditional_dft (input_x, cond)
261+ }
262+
263+ <domain: "pkg.custom", opset_import: [ "" : 18]>
264+ conditional_dft (x, cond) => (result) {
265+ result = If (cond) <then_branch: graph = then_graph () => (out) {
266+ shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
267+ reshape_x = Reshape (x, shape_a)
268+ dft = DFT <axis = 2, onesided = 1> (reshape_x)
269+ shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
270+ out = Reshape (dft, shape_c)
271+ }, else_branch: graph = else_graph () => (out) {
272+ shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
273+ out = Reshape (x, shape_c)
274+ }>
275+ }
276+ """
277+ )
278+ # Verify the function exists with correct initial state
279+ self .assertEqual (len (model .functions ), 1 )
280+ func = model .functions [("pkg.custom" , "conditional_dft" , "" )]
281+ self .assertEqual (len (func ), 1 ) # 1 node (If) in the function
282+
283+ # Verify the If node has subgraphs
284+ if_node = func [0 ]
285+ self .assertEqual (if_node .op_type , "If" )
286+ then_branch = if_node .attributes ["then_branch" ].as_graph ()
287+ else_branch = if_node .attributes ["else_branch" ].as_graph ()
288+ self .assertEqual (len (then_branch ), 5 ) # 5 nodes in then_branch
289+ self .assertEqual (len (else_branch ), 2 ) # 2 nodes in else_branch
290+
291+ target_version = 20
292+ # Use internal API to test function version conversion without inlining
293+ version_converter .convert_version (model , target_version = target_version )
294+ self .assertEqual (model .opset_imports ["" ], target_version )
295+
296+ # Verify nodes inside the function's If node subgraphs were version-converted
297+ func = model .functions [("pkg.custom" , "conditional_dft" , "" )]
298+ if_node = func [0 ]
299+ self .assertEqual (if_node .op_type , "If" )
300+ self .assertEqual (if_node .version , 20 )
301+
302+ # Check then_branch subgraph nodes
303+ then_branch = if_node .attributes ["then_branch" ].as_graph ()
304+ # After DFT adapter, a new Constant node is inserted for dft_length
305+ self .assertEqual (len (then_branch ), 6 ) # 5 + 1 new Constant for DFT
306+ dft_node = None
307+ for node in then_branch :
308+ self .assertEqual (node .version , 20 )
309+ if node .op_type == "DFT" :
310+ dft_node = node
311+ self .assertIsNotNone (dft_node )
312+ self .assertEqual (len (dft_node .inputs ), 3 ) # DFT 19->20 adds dft_length input
313+
314+ # Check else_branch subgraph nodes
315+ else_branch = if_node .attributes ["else_branch" ].as_graph ()
316+ self .assertEqual (len (else_branch ), 2 )
317+ for node in else_branch :
318+ self .assertEqual (node .version , 20 )
245319
246320
247321class VersionConverter20to21Test (unittest .TestCase ):
0 commit comments