Skip to content

Commit 4ed1713

Browse files
committed
nit: copy_ops -> get_copy_ops
Signed-off-by: gcunhase <[email protected]>
1 parent cb36cb5 commit 4ed1713

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

modelopt/onnx/op_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def copy_ops():
99+
def get_copy_ops():
100100
"""Returns list of copy operators."""
101101
return [
102102
"Flatten",
@@ -120,7 +120,7 @@ def copy_ops():
120120

121121
def is_copy_op(op_type: str):
122122
"""Returns whether the given op is a copy operator or not."""
123-
return op_type in copy_ops()
123+
return op_type in get_copy_ops()
124124

125125

126126
def is_linear_op(op_type: str):

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
3030

3131
from modelopt.onnx.logging_config import logger
32-
from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op
32+
from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
3333
from modelopt.onnx.quantization.ort_utils import create_inference_session
3434
from modelopt.onnx.utils import (
3535
find_lowest_common_ancestor,
@@ -207,7 +207,7 @@ def _get_backbone(root: Node):
207207
["Mul", "Sigmoid", "BatchNormalization", conv_type],
208208
]
209209
for idx, path_type in enumerate(fusible_linear_path_types):
210-
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()):
210+
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
211211
return _get_backbone(node)
212212

213213
return None

0 commit comments

Comments
 (0)