
Commit d0e83ed

aboubezari and galagam authored
[Autocast] Add low precision autocasting support for Resize op (#436)
Signed-off-by: Ali Boubezari <[email protected]>
Signed-off-by: Gal Hubara Agam <[email protected]>
Co-authored-by: Gal Hubara Agam <[email protected]>
1 parent 7ccaa53 commit d0e83ed

File tree

2 files changed: +212 −11 lines changed


modelopt/onnx/autocast/precisionconverter.py

Lines changed: 57 additions & 7 deletions
@@ -47,11 +47,15 @@
 
 ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()]
 
-OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Resize", "Upsample", "NonMaxSuppression", "Celu"]
+OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
 
 # Temporarily block these ops in low precision, as they are not supported yet
 OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"])
 
+# Mapping of op types to indices of inputs that should not be converted to low precision.
+SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
+SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
+
 
 class PrecisionConverter:
     """Precision conversion module for ONNX models.
@@ -69,7 +73,7 @@ def __init__(
         model: onnx.ModelProto,
         value_info_map: dict[str, onnx.ValueInfoProto],
         initializer_map: dict[str, onnx.TensorProto],
-        node_to_init_map: dict[str, list[str]],
+        node_to_init_map: dict[str, list[onnx.TensorProto]],
         keep_io_types: bool = False,
         low_precision_type: str = "fp16",
         init_conversion_max_bytes: int | None = None,
@@ -156,14 +160,23 @@ def convert(
             if input.type.tensor_type.elem_type == self.high_precision_type.onnx_type:
                 input.type.tensor_type.elem_type = self.low_precision_type.onnx_type
 
-        cast_down_tensors, cast_up_tensors = self._get_tensors_to_cast(low_precision_nodes)
+        cast_down_tensors, cast_up_tensors, fp32_input_to_low_precision_node = (
+            self._get_tensors_to_cast(low_precision_nodes)
+        )
         logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
         logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
 
         # Add cast nodes for "cast_up" tensors
         for tensor_name in cast_up_tensors:
+            exclude_consumers = low_precision_nodes
+            if tensor_name in fp32_input_to_low_precision_node:
+                # For the low precision nodes that take a FP32 input, we don't exclude it from
+                # casting up so that the input can be converted to FP32 as expected.
+                exclude_consumers = list(
+                    set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
+                )
             self._add_cast(
-                tensor_name, self.high_precision_type, exclude_consumers=low_precision_nodes
+                tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
             )
 
         # Add cast nodes for "cast_down" tensors
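For illustration, a minimal standalone sketch of the exclude_consumers logic above, using toy node names (not from the commit):

    from types import SimpleNamespace

    # Three low-precision nodes; "resize" consumes the FP32-only tensor "scales_tensor".
    low_precision_nodes = ["add", "resize", "relu"]
    fp32_input_to_low_precision_node = {"scales_tensor": SimpleNamespace(name="resize")}

    for tensor_name in ["scales_tensor", "feature_map"]:
        exclude_consumers = low_precision_nodes
        if tensor_name in fp32_input_to_low_precision_node:
            # Drop the consuming node from the exclusion set so a Cast-to-FP32
            # is inserted on the edge feeding it.
            exclude_consumers = list(
                set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
            )
        print(tensor_name, sorted(exclude_consumers))
    # scales_tensor ['add', 'relu']            -> "resize" gets the cast-up on its scales edge
    # feature_map   ['add', 'relu', 'resize']  -> unchanged behavior for ordinary tensors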
@@ -409,15 +422,25 @@ def _filter_unsupported_op_types(
         )
         return high_precision_nodes, low_precision_nodes
 
-    def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str], list[str]]:
+    def _get_tensors_to_cast(
+        self, low_precision_nodes: list[str]
+    ) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
         cast_to_fp16 = []  # Tensors to cast down to FP16
         cast_to_fp32 = []  # Tensors to cast up to FP32
+        # Keep track of the low precision nodes that take a FP32 input.
+        fp32_input_to_low_precision_node = {}
 
         # Get tensors for FP16 nodes
         for node in self.model.graph.node:
             if node.name in low_precision_nodes:
                 # Cast inputs to FP16 nodes down to FP16
-                cast_to_fp16.extend(node.input)
+                for input in node.input:
+                    if self._should_skip_low_precision_input_conversion(node, input):
+                        cast_to_fp32.append(input)
+                        fp32_input_to_low_precision_node[input] = node
+                    else:
+                        cast_to_fp16.append(input)
+
                 # Cast outputs from FP16 nodes up to FP32
                 cast_to_fp32.extend(node.output)
 
@@ -444,7 +467,7 @@ def _get_tensors_to_cast(
 
         logger.debug(f"tensors to cast to FP16: {cast_to_fp16}")
         logger.debug(f"tensors to cast to FP32: {cast_to_fp32}")
-        return cast_to_fp16, cast_to_fp32
+        return cast_to_fp16, cast_to_fp32, fp32_input_to_low_precision_node
 
     def _convert_initializers(
         self, low_precision_nodes: list[str], high_precision_nodes: list[str]
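To see what the extended _get_tensors_to_cast now returns, here is a self-contained sketch of the per-input partition for a single node (index-based for brevity; the real code resolves the index through _should_skip_low_precision_input_conversion):

    import onnx
    from onnx import helper

    SKIP_FP16 = {"Resize": {2}}  # the fp16 mapping introduced above

    def partition_node_tensors(node: onnx.NodeProto):
        # Split one low-precision node's tensors into cast-down / cast-up sets (fp16 mode).
        cast_to_fp16, cast_to_fp32, fp32_input_to_node = [], [], {}
        skip = SKIP_FP16.get(node.op_type, set())
        for idx, name in enumerate(node.input):
            if idx in skip:
                cast_to_fp32.append(name)        # e.g. Resize "scales" stays FP32
                fp32_input_to_node[name] = node
            else:
                cast_to_fp16.append(name)
        cast_to_fp32.extend(node.output)         # outputs of low-precision nodes are cast back up
        return cast_to_fp16, cast_to_fp32, fp32_input_to_node

    resize = helper.make_node("Resize", ["add_out", "roi", "scales"], ["resize_out"], name="resize")
    assert partition_node_tensors(resize) == (
        ["add_out", "roi"], ["scales", "resize_out"], {"scales": resize}
    )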
@@ -557,6 +580,8 @@ def convert_initializer(
         for node in self.model.graph.node:
             if node.name in low_precision_nodes:
                 for init in self.node_to_init_map[node.name]:
+                    if self._should_skip_low_precision_input_conversion(node, init.name):
+                        continue
                     modified |= convert_initializer(
                         init,
                         node,
@@ -1069,3 +1094,28 @@ def _sanitize_model(self):
         )
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
+
+    def _should_skip_low_precision_input_conversion(
+        self, node: onnx.NodeProto, input_name: str
+    ) -> bool:
+        """Check if the input should be skipped for low precision conversion.
+
+        This is used for nodes that have inputs that MUST remain in FP32.
+        """
+        match self.low_precision_type.str_short:
+            case "fp16":
+                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
+            case "bf16":
+                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16
+            case _:
+                raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
+
+        if node.op_type in skip_inputs_map:
+            # Figure out the index of the input in the node input
+            inputs_lst = list(node.input)
+            if input_name not in inputs_lst:
+                raise ValueError(f"Input {input_name} not found in node {node.name}.")
+            input_index = inputs_lst.index(input_name)
+            # Check if we should skip this input for low precision conversion
+            return input_index in skip_inputs_map[node.op_type]
+        return False
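The predicate's behavior can be exercised standalone. A sketch assumed equivalent to the method above, with low_precision_type passed explicitly rather than read from self:

    import onnx
    from onnx import helper

    SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
    SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}

    def should_skip(node: onnx.NodeProto, input_name: str, low_precision_type: str) -> bool:
        match low_precision_type:
            case "fp16":
                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
            case "bf16":
                skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16
            case _:
                raise ValueError(f"Unsupported low precision type: {low_precision_type}")
        if node.op_type in skip_inputs_map:
            inputs_lst = list(node.input)
            if input_name not in inputs_lst:
                raise ValueError(f"Input {input_name} not found in node {node.name}.")
            return inputs_lst.index(input_name) in skip_inputs_map[node.op_type]
        return False

    resize = helper.make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
    assert should_skip(resize, "scales", "fp16")   # scales must stay FP32 in both modes
    assert should_skip(resize, "roi", "bf16")      # roi has no bfloat16 variant
    assert not should_skip(resize, "roi", "fp16")  # roi may be cast to fp16
    assert not should_skip(resize, "x", "bf16")    # the data input may always be cast down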

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 155 additions & 4 deletions
@@ -179,7 +179,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
     )
 
     # Test when relu node is in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["relu"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["relu"])
     assert "add_output" in cast_down  # Input to relu should be cast down
     assert "Y" in cast_up  # Output of relu should be cast up
     if not keep_io_types:
@@ -188,7 +188,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
         )  # Input to gemm should be cast up, because network input are converted to FP16
 
     # Test when add node is in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["add"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["add"])
     assert "gemm_output" in cast_down  # Input to add should be cast down
     assert "add_init" not in cast_down  # Initializer should not be in cast list
     assert "add_output" in cast_up  # Output of add should be cast up
@@ -314,13 +314,13 @@ def test_get_tensors_to_cast_multiple_consumers(
     )
 
     # Test when gemm2 and add1 nodes are in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1"])
     assert "X" in cast_down  # Input to gemm2 should be cast down
     assert "gemm2_output" in cast_up  # Output of gemm2 should be cast up
     assert "Y1" in cast_up  # Output of add1 should be cast up
 
     # Test when all nodes except gemm1 are in low precision
-    cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
+    cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
     assert "gemm1_output" in cast_down  # Input to gemm2 should be cast down
     assert "Y1" in cast_up  # Output of add1 should be cast up
     assert "Y2" in cast_up  # Output of add2 should be cast up
@@ -1173,3 +1173,154 @@ def test_casted_input_to_output_model(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
     onnx.checker.check_model(converted_model)
+
+
+@pytest.fixture
+def create_model_with_resize_op():
+    """
+    Creates an ONNX model that contains a resize operation in the middle of the computation flow.
+
+    The model structure:
+    X -> Add -> Resize -> Relu -> Y
+    """
+    # Create inputs and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])
+
+    # Create initializer for add operation
+    add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
+    add_init = numpy_helper.from_array(add_const, name="add_const")
+
+    # Create resize parameters
+    roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")
+    scales = numpy_helper.from_array(
+        np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32), name="scales"
+    )
+
+    # Create nodes: Add -> Resize -> Relu
+    add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
+    resize_node = helper.make_node(
+        "Resize", ["add_out", "roi", "scales"], ["resize_out"], name="resize", mode="nearest"
+    )
+    relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")
+
+    # Build the graph
+    graph = helper.make_graph(
+        [add_node, resize_node, relu_node],
+        "model_with_resize",
+        [x],
+        [y],
+        [add_init, roi_empty, scales],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_resize")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.fixture
+def create_model_with_resize_op_tensor_scales():
+    """
+    Creates an ONNX model that contains a resize operation where the scales
+    are computed from a second network input through an Add operation.
+
+    The model structure:
+    X            -> Add -> Resize -> Relu -> Y
+    scales_input -> Add -> scales_tensor /
+    """
+    # Create inputs and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
+    scales_input = helper.make_tensor_value_info("scales_input", TensorProto.FLOAT, [4])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])
+
+    # Create initializers
+    add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
+    add_init = numpy_helper.from_array(add_const, name="add_const")
+
+    # Create scales computation initializer (add small offset to input scales)
+    scales_offset = np.array(
+        [0.0, 0.0, 1.0, 1.0], dtype=np.float32
+    )  # Will result in [1,1,2,2] when added to [1,1,1,1] input
+    scales_offset_init = numpy_helper.from_array(scales_offset, name="scales_offset")
+
+    # Create resize parameters
+    roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")
+
+    # Create nodes
+    add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
+    scales_add_node = helper.make_node(
+        "Add", ["scales_input", "scales_offset"], ["scales_tensor"], name="scales_add"
+    )
+    resize_node = helper.make_node(
+        "Resize", ["add_out", "roi", "scales_tensor"], ["resize_out"], name="resize", mode="nearest"
+    )
+    relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")
+
+    # Build the graph
+    graph = helper.make_graph(
+        [add_node, scales_add_node, resize_node, relu_node],
+        "model_with_resize_tensor_scales",
+        [x, scales_input],  # Two network inputs
+        [y],
+        [add_init, scales_offset_init, roi_empty],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_resize_tensor_scales")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_resize_op_initializer_conversion(
+    create_model_with_resize_op, keep_io_types, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
+    )
+    onnx.checker.check_model(converted_model)
+
+
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_resize_op_tensor_scales_conversion(
+    create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = (
+        create_model_with_resize_op_tensor_scales
+    )
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
+    )
+    onnx.checker.check_model(converted_model)
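One additional check these tests could make (not part of the commit) is that the Resize scales initializer is left untouched, e.g. inside test_resize_op_initializer_conversion after convert():

    # Hypothetical follow-up assertion: "scales" feeds Resize input index 2
    # and must remain FP32 in both fp16 and bf16 modes.
    scales_init = next(init for init in converted_model.graph.initializer if init.name == "scales")
    assert scales_init.data_type == TensorProto.FLOAT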
