Skip to content

Commit c855a19

Browse files
committed
Utility and example for custom op expansion
1 parent cba1325 commit c855a19

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

examples/custom_op_expansion.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""A utility and an example showing how onnxscript functions can be used to define function expansions
4+
and be used with the inliner to replace calls to the custom function with an expanded subgraph.
5+
This is useful to perform certain classes of graph surgery easily."""
6+
7+
import onnx
8+
import onnxscript
9+
from onnxscript import script, FLOAT, opset22 as op
10+
11+
12+
local = onnxscript.values.Opset("local", 1)
13+
14+
# Example Model: Actual models can come from ModelBuilder or Exporter or any other source.
15+
# Models can contain calls to custom operations (from a custom domain like 'local' here or
16+
# even "com.microsoft" etc.)
17+
@script()
18+
def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
19+
DoubleX = op.Add(X, X)
20+
YSquare = op.Mul(Y, Y)
21+
# Example call to a custom operation
22+
Temp1 = local.CustomOp1(DoubleX, YSquare)
23+
# Another call to a custom operation with an attribute
24+
Temp2 = local.CustomOp2(Temp1, alp=0.9)
25+
return Temp2
26+
27+
# Define expansions for custom operations as onnxscript functions
28+
@script(opset=local)
29+
def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
30+
Temp1 = op.Sub(X, Y)
31+
return op.Div(Temp1, X)
32+
33+
@script(opset=local)
34+
def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
35+
Temp2 = op.Elu(X, alpha=alp)
36+
return op.Mul(Temp2, Temp2)
37+
38+
# Now, we can replace the custom operations in the model with their expansions:
39+
40+
functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()]
41+
42+
model = model_script.to_model_proto()
43+
44+
print("Original Model with custom operations:")
45+
print(onnx.printer.to_text(model))
46+
47+
import onnxscript.utils.replace as replace
48+
updated_model = replace.replace_functions(model, functions)
49+
50+
print("\nUpdated Model after replacing custom operations with their expansions:")
51+
print(onnx.printer.to_text(updated_model))
52+

onnxscript/utils/replace.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""A utility function to replace custom operations in a model with their expansions"""
4+
5+
from typing import Sequence
6+
import onnx
7+
import onnx_ir as ir
8+
import onnx_ir.passes.common as common_passes
9+
10+
def replace_functions(model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]) -> onnx.ModelProto:
11+
'''A utility function to replace custom operations in a model with their expansions:
12+
Args:
13+
model: An ONNX ModelProto possibly containing calls to custom operations.
14+
functions: A sequence of FunctionProto defining the expansions for the custom operations.
15+
Returns:
16+
An updated ModelProto with custom operations replaced by their expansions.
17+
'''
18+
irmodel = ir.from_proto(model)
19+
irfunctions = [ir.from_proto(func) for func in functions]
20+
model_functions = irmodel.functions
21+
if len(model_functions) != 0:
22+
# Since we use inlining, check that there are no model-local functions.
23+
raise ValueError("Input model cannot have model-local functions.")
24+
for func in irfunctions:
25+
model_functions[func.identifier()] = func
26+
27+
# TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values.
28+
common_passes.InlinePass()(irmodel)
29+
common_passes.RemoveUnusedOpsetsPass()(irmodel)
30+
return ir.to_proto(irmodel)

0 commit comments

Comments
 (0)