Skip to content

Commit 9dbf685

Browse files
authored
Provide inplace replacement util (#2708)
Redo #2703, thanks to Titai. Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 6247ac1 commit 9dbf685

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

onnxscript/utils/replace.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,17 @@
99
import onnx_ir.passes.common as common_passes
1010

1111

12-
def replace_functions(
13-
model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]
14-
) -> onnx.ModelProto:
12+
def replace_functions_inplace(irmodel: ir.Model, irfunctions: Sequence[ir.Function]) -> None:
1513
"""A utility function to replace custom operations in a model with their expansions:
14+
15+
The model is updated in-place.
16+
1617
Args:
17-
model: An ONNX ModelProto possibly containing calls to custom operations.
18-
functions: A sequence of FunctionProto defining the expansions for the custom operations.
18+
irmodel: An ONNX model possibly containing calls to custom operations.
19+
irfunctions: A sequence of functions defining the expansions for the custom operations.
20+
1921
20-
Returns:
21-
An updated ModelProto with custom operations replaced by their expansions.
2222
"""
23-
irmodel = ir.from_proto(model)
24-
irfunctions = [ir.from_proto(func) for func in functions]
2523
model_functions = irmodel.functions
2624
if len(model_functions) != 0:
2725
# Since we use inlining, check that there are no model-local functions.
@@ -32,4 +30,20 @@ def replace_functions(
3230
# TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values.
3331
common_passes.InlinePass()(irmodel)
3432
common_passes.RemoveUnusedOpsetsPass()(irmodel)
33+
34+
35+
def replace_functions(
36+
model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]
37+
) -> onnx.ModelProto:
38+
"""A utility function to replace custom operations in a model with their expansions:
39+
Args:
40+
model: An ONNX ModelProto possibly containing calls to custom operations.
41+
functions: A sequence of FunctionProto defining the expansions for the custom operations.
42+
43+
Returns:
44+
An updated ModelProto with custom operations replaced by their expansions.
45+
"""
46+
irmodel = ir.from_proto(model)
47+
irfunctions = [ir.from_proto(func) for func in functions]
48+
replace_functions_inplace(irmodel, irfunctions)
3549
return ir.to_proto(irmodel)

0 commit comments

Comments
 (0)