Skip to content

Commit 46bb44a

Browse files
committed
Fix lint
1 parent c855a19 commit 46bb44a

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

examples/custom_op_expansion.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
# Licensed under the MIT License.
33
"""A utility and an example showing how onnxscript functions can be used to define function expansions
44
and be used with the inliner to replace calls to the custom function with an expanded subgraph.
5-
This is useful to perform certain classes of graph surgery easily."""
5+
This is useful to perform certain classes of graph surgery easily.
6+
"""
67

78
import onnx
8-
import onnxscript
9-
from onnxscript import script, FLOAT, opset22 as op
109

10+
import onnxscript
1111

12+
script = onnxscript.script
13+
FLOAT = onnxscript.FLOAT
14+
op = onnxscript.values.opset22
1215
local = onnxscript.values.Opset("local", 1)
1316

17+
1418
# Example Model: Actual models can come from ModelBuilder or Exporter or any other source.
1519
# Models can contain calls to custom operations (from a custom domain like 'local' here or
1620
# even "com.microsoft" etc.)
@@ -24,17 +28,20 @@ def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
2428
Temp2 = local.CustomOp2(Temp1, alp=0.9)
2529
return Temp2
2630

31+
2732
# Define expansions for custom operations as onnxscript functions
2833
@script(opset=local)
2934
def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
3035
Temp1 = op.Sub(X, Y)
3136
return op.Div(Temp1, X)
3237

38+
3339
@script(opset=local)
3440
def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
3541
Temp2 = op.Elu(X, alpha=alp)
3642
return op.Mul(Temp2, Temp2)
3743

44+
3845
# Now, we can replace the custom operations in the model with their expansions:
3946

4047
functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()]
@@ -45,8 +52,8 @@ def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
4552
print(onnx.printer.to_text(model))
4653

4754
import onnxscript.utils.replace as replace
55+
4856
updated_model = replace.replace_functions(model, functions)
4957

5058
print("\nUpdated Model after replacing custom operations with their expansions:")
5159
print(onnx.printer.to_text(updated_model))
52-

onnxscript/utils/replace.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,23 @@
33
"""A utility function to replace custom operations in a model with their expansions"""
44

55
from typing import Sequence
6+
67
import onnx
78
import onnx_ir as ir
89
import onnx_ir.passes.common as common_passes
910

10-
def replace_functions(model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]) -> onnx.ModelProto:
11-
'''A utility function to replace custom operations in a model with their expansions:
11+
12+
def replace_functions(
13+
model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]
14+
) -> onnx.ModelProto:
15+
"""A utility function to replace custom operations in a model with their expansions:
1216
Args:
1317
model: An ONNX ModelProto possibly containing calls to custom operations.
1418
functions: A sequence of FunctionProto defining the expansions for the custom operations.
19+
1520
Returns:
1621
An updated ModelProto with custom operations replaced by their expansions.
17-
'''
22+
"""
1823
irmodel = ir.from_proto(model)
1924
irfunctions = [ir.from_proto(func) for func in functions]
2025
model_functions = irmodel.functions

0 commit comments

Comments
 (0)