Commit b9a8e3d

[Autocast] Add low precision autocasting support for Resize op
Signed-off-by: Ali Boubezari <[email protected]>

cleanup
Signed-off-by: Ali Boubezari <[email protected]>

Update modelopt/onnx/autocast/precisionconverter.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <[email protected]>

[5571471] Fix quantization logic for residual branches with different backbones (#425)
Signed-off-by: gcunhase <[email protected]>

modify casts; testing

Add support for tensor scales
Signed-off-by: Ali Boubezari <[email protected]>

Generalize & automate skipping inputs; only skip index 2 for bfloat16
Signed-off-by: Ali Boubezari <[email protected]>

bugfixes
Signed-off-by: Ali Boubezari <[email protected]>
1 parent bffe2ff commit b9a8e3d

File tree

modelopt/onnx/autocast/precisionconverter.py
tests/unit/onnx/autocast/test_precisionconverter.py

2 files changed (+210, -11 lines)


modelopt/onnx/autocast/precisionconverter.py

Lines changed: 57 additions & 7 deletions
@@ -47,11 +47,15 @@
 
 ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()]
 
-OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Resize", "Upsample", "NonMaxSuppression", "Celu"]
+OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
 
 # Temporarily block these ops in low precision, as they are not supported yet
 OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"])
 
+# Mapping of op types to indices of inputs that should not be converted to low precision.
+SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1}}
+SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
+
 
 class PrecisionConverter:
     """Precision conversion module for ONNX models.
@@ -69,7 +73,7 @@ def __init__(
         model: onnx.ModelProto,
         value_info_map: dict[str, onnx.ValueInfoProto],
         initializer_map: dict[str, onnx.TensorProto],
-        node_to_init_map: dict[str, list[str]],
+        node_to_init_map: dict[str, list[onnx.TensorProto]],
         keep_io_types: bool = False,
         low_precision_type: str = "fp16",
         init_conversion_max_bytes: int | None = None,
@@ -156,14 +160,23 @@ def convert(
             if input.type.tensor_type.elem_type == self.high_precision_type.onnx_type:
                 input.type.tensor_type.elem_type = self.low_precision_type.onnx_type
 
-        cast_down_tensors, cast_up_tensors = self._get_tensors_to_cast(low_precision_nodes)
+        cast_down_tensors, cast_up_tensors, fp32_input_to_low_precision_node = (
+            self._get_tensors_to_cast(low_precision_nodes)
+        )
         logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
         logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
 
         # Add cast nodes for "cast_up" tensors
         for tensor_name in cast_up_tensors:
+            exclude_consumers = low_precision_nodes
+            if tensor_name in fp32_input_to_low_precision_node:
+                # For the low precision nodes that take a FP32 input, we don't exclude it from
+                # casting up so that the input can be converted to FP32 as expected.
+                exclude_consumers = list(
+                    set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
+                )
             self._add_cast(
-                tensor_name, self.high_precision_type, exclude_consumers=low_precision_nodes
+                tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
             )
 
         # Add cast nodes for "cast_down" tensors
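
In plain terms: when a tensor must reach one low-precision node in FP32, that node is removed from the exclusion set so the new Cast-to-FP32 output gets wired into its input, while the remaining low-precision consumers keep the original tensor. A small sketch of the set arithmetic with hypothetical names (the real map stores onnx.NodeProto values, hence the .name access above):

    # Sketch of the exclude_consumers computation (hypothetical node/tensor names).
    low_precision_nodes = ["add", "resize", "relu"]
    fp32_input_to_node_name = {"scales_tensor": "resize"}  # simplified: tensor -> node name

    tensor_name = "scales_tensor"
    exclude_consumers = list(set(low_precision_nodes) - {fp32_input_to_node_name[tensor_name]})
    assert "resize" not in exclude_consumers  # resize receives the Cast-to-FP32 output
    assert set(exclude_consumers) == {"add", "relu"}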
@@ -409,15 +422,25 @@ def _filter_unsupported_op_types(
         )
         return high_precision_nodes, low_precision_nodes
 
-    def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str], list[str]]:
+    def _get_tensors_to_cast(
+        self, low_precision_nodes: list[str]
+    ) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
         cast_to_fp16 = []  # Tensors to cast down to FP16
         cast_to_fp32 = []  # Tensors to cast up to FP32
+        # Keep track of the low precision nodes that take a FP32 input.
+        fp32_input_to_low_precision_node = {}
 
         # Get tensors for FP16 nodes
         for node in self.model.graph.node:
             if node.name in low_precision_nodes:
                 # Cast inputs to FP16 nodes down to FP16
-                cast_to_fp16.extend(node.input)
+                for input in node.input:
+                    if self._should_skip_low_precision_input_conversion(node, input):
+                        cast_to_fp32.append(input)
+                        fp32_input_to_low_precision_node[input] = node
+                    else:
+                        cast_to_fp16.append(input)
 
                 # Cast outputs from FP16 nodes up to FP32
                 cast_to_fp32.extend(node.output)
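
The effect of the new loop in isolation: a low-precision node's inputs are partitioned between the two cast lists instead of being cast down wholesale. A toy rendition, with a static index set standing in for the _should_skip_low_precision_input_conversion call:

    # Toy partition of a Resize node's inputs between the cast lists.
    node_inputs = ["add_out", "roi", "scales"]  # hypothetical tensor names
    skip_indices = {1, 2}                       # the bf16 mapping for Resize

    cast_to_fp16, cast_to_fp32 = [], []
    for idx, name in enumerate(node_inputs):
        (cast_to_fp32 if idx in skip_indices else cast_to_fp16).append(name)
    print(cast_to_fp16)  # ['add_out']
    print(cast_to_fp32)  # ['roi', 'scales']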

@@ -444,7 +467,7 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str
 
         logger.debug(f"tensors to cast to FP16: {cast_to_fp16}")
         logger.debug(f"tensors to cast to FP32: {cast_to_fp32}")
-        return cast_to_fp16, cast_to_fp32
+        return cast_to_fp16, cast_to_fp32, fp32_input_to_low_precision_node
 
     def _convert_initializers(
         self, low_precision_nodes: list[str], high_precision_nodes: list[str]
@@ -557,6 +580,8 @@ def convert_initializer(
         for node in self.model.graph.node:
             if node.name in low_precision_nodes:
                 for init in self.node_to_init_map[node.name]:
+                    if self._should_skip_low_precision_input_conversion(node, init.name):
+                        continue
                     modified |= convert_initializer(
                         init,
                         node,
@@ -1069,3 +1094,28 @@ def _sanitize_model(self):
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
+
+    def _should_skip_low_precision_input_conversion(
+        self, node: onnx.NodeProto, input_name: str
+    ) -> bool:
+        """Check if the input should be skipped for low precision conversion.
+
+        This is used for nodes that have inputs that MUST remain in FP32.
+        """
+        match self.low_precision_type.str_short:
+            case "fp16":
+                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
+            case "bf16":
+                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16
+            case _:
+                raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
+
+        if node.op_type in skip_inputs_map:
+            # Figure out the index of the input in the node input
+            inputs_lst = list(node.input)
+            if input_name not in inputs_lst:
+                raise ValueError(f"Input {input_name} not found in node {node.name}.")
+            input_index = inputs_lst.index(input_name)
+            # Check if we should skip this input for low precision conversion
+            return input_index in skip_inputs_map[node.op_type]
+        return False
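
A standalone sketch of the same decision, exercised for both precision modes to show why the two mappings differ (plain string flags stand in for self.low_precision_type.str_short):

    # Standalone sketch of the skip decision for both precision modes.
    SKIP_FP16 = {"Resize": {1}}
    SKIP_BF16 = {"Resize": {1, 2}}

    def should_skip(op_type: str, inputs: list[str], input_name: str, mode: str) -> bool:
        skip_map = SKIP_FP16 if mode == "fp16" else SKIP_BF16
        if op_type in skip_map and input_name in inputs:
            return inputs.index(input_name) in skip_map[op_type]
        return False

    inputs = ["add_out", "roi", "scales"]
    print(should_skip("Resize", inputs, "scales", "fp16"))  # False: scales may be cast down
    print(should_skip("Resize", inputs, "scales", "bf16"))  # True: scales must stay FP32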

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 153 additions & 4 deletions
@@ -179,7 +179,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
     )
 
     # Test when relu node is in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["relu"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["relu"])
     assert "add_output" in cast_down  # Input to relu should be cast down
     assert "Y" in cast_up  # Output of relu should be cast up
     if not keep_io_types:
@@ -188,7 +188,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
         )  # Input to gemm should be cast up, because network input are converted to FP16
 
     # Test when add node is in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["add"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["add"])
     assert "gemm_output" in cast_down  # Input to add should be cast down
     assert "add_init" not in cast_down  # Initializer should not be in cast list
     assert "add_output" in cast_up  # Output of add should be cast up
@@ -314,13 +314,13 @@ def test_get_tensors_to_cast_multiple_consumers(
     )
 
    # Test when gemm2 and add1 nodes are in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1"])
     assert "X" in cast_down  # Input to gemm2 should be cast down
     assert "gemm2_output" in cast_up  # Output of gemm2 should be cast up
     assert "Y1" in cast_up  # Output of add1 should be cast up
 
     # Test when all nodes except gemm1 are in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
     assert "gemm1_output" in cast_down  # Input to gemm2 should be cast down
     assert "Y1" in cast_up  # Output of add1 should be cast up
     assert "Y2" in cast_up  # Output of add2 should be cast up
@@ -1173,3 +1173,152 @@ def test_casted_input_to_output_model(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+
+
+@pytest.fixture
+def create_model_with_resize_op():
+    """
+    Creates an ONNX model that contains a Resize operation in the middle of the computation flow.
+
+    The model structure:
+        X -> Add -> Resize -> Relu -> Y
+    """
+    # Create inputs and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])
+
+    # Create initializer for add operation
+    add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
+    add_init = numpy_helper.from_array(add_const, name="add_const")
+
+    # Create resize parameters
+    roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")
+    scales = numpy_helper.from_array(
+        np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32), name="scales"
+    )
+
+    # Create nodes: Add -> Resize -> Relu
+    add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
+    resize_node = helper.make_node(
+        "Resize", ["add_out", "roi", "scales"], ["resize_out"], name="resize", mode="nearest"
+    )
+    relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")
+
+    # Build the graph
+    graph = helper.make_graph(
+        [add_node, resize_node, relu_node],
+        "model_with_resize",
+        [x],
+        [y],
+        [add_init, roi_empty, scales],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_resize")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.fixture
+def create_model_with_resize_op_tensor_scales():
+    """
+    Creates an ONNX model that contains a Resize operation where the scales
+    are computed from a second network input through an Add operation.
+
+    The model structure:
+        X -> Add -> Resize -> Relu -> Y
+        scales_input -> Add -> scales_tensor /
+    """
+    # Create inputs and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
+    scales_input = helper.make_tensor_value_info("scales_input", TensorProto.FLOAT, [4])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])
+
+    # Create initializers
+    add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
+    add_init = numpy_helper.from_array(add_const, name="add_const")
+
+    # Create scales computation initializer (add small offset to input scales)
+    scales_offset = np.array(
+        [0.0, 0.0, 1.0, 1.0], dtype=np.float32
+    )  # Will result in [1,1,2,2] when added to [1,1,1,1] input
+    scales_offset_init = numpy_helper.from_array(scales_offset, name="scales_offset")
+
+    # Create resize parameters
+    roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")
+
+    # Create nodes
+    add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
+    scales_add_node = helper.make_node(
+        "Add", ["scales_input", "scales_offset"], ["scales_tensor"], name="scales_add"
+    )
+    resize_node = helper.make_node(
+        "Resize", ["add_out", "roi", "scales_tensor"], ["resize_out"], name="resize", mode="nearest"
+    )
+    relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")
+
+    # Build the graph
+    graph = helper.make_graph(
+        [add_node, scales_add_node, resize_node, relu_node],
+        "model_with_resize_tensor_scales",
+        [x, scales_input],  # Two network inputs
+        [y],
+        [add_init, scales_offset_init, roi_empty],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_resize_tensor_scales")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_resize_op_initializer_conversion(
+    create_model_with_resize_op, keep_io_types, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
+    )
+    onnx.checker.check_model(converted_model)
+
+
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_resize_op_tensor_scales_conversion(
+    create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = (
+        create_model_with_resize_op_tensor_scales
+    )
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
+    )
+    onnx.checker.check_model(converted_model)
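
To exercise just the new coverage locally, pytest's keyword filter should work, e.g. pytest tests/unit/onnx/autocast/test_precisionconverter.py -k resize_op (both new tests match the resize_op substring).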
