64 changes: 57 additions & 7 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -47,11 +47,15 @@

ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()]

OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Resize", "Upsample", "NonMaxSuppression", "Celu"]
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]

# Temporarily block these ops in low precision, as they are not supported yet
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"])

# Mapping of op types to indices of inputs that should not be converted to low precision.
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}


class PrecisionConverter:
"""Precision conversion module for ONNX models.
@@ -69,7 +73,7 @@ def __init__(
model: onnx.ModelProto,
value_info_map: dict[str, onnx.ValueInfoProto],
initializer_map: dict[str, onnx.TensorProto],
node_to_init_map: dict[str, list[str]],
node_to_init_map: dict[str, list[onnx.TensorProto]],
keep_io_types: bool = False,
low_precision_type: str = "fp16",
init_conversion_max_bytes: int | None = None,
@@ -156,14 +160,23 @@ def convert(
if input.type.tensor_type.elem_type == self.high_precision_type.onnx_type:
input.type.tensor_type.elem_type = self.low_precision_type.onnx_type

cast_down_tensors, cast_up_tensors = self._get_tensors_to_cast(low_precision_nodes)
cast_down_tensors, cast_up_tensors, fp32_input_to_low_precision_node = (
self._get_tensors_to_cast(low_precision_nodes)
)
logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")

# Add cast nodes for "cast_up" tensors
for tensor_name in cast_up_tensors:
exclude_consumers = low_precision_nodes
if tensor_name in fp32_input_to_low_precision_node:
# For a low precision node that takes an FP32 input, we don't exclude that node
# from the cast-up, so that the input is cast to FP32 as expected.
exclude_consumers = list(
set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
)
self._add_cast(
tensor_name, self.high_precision_type, exclude_consumers=low_precision_nodes
tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
)

# Add cast nodes for "cast_down" tensors
@@ -409,15 +422,25 @@ def _filter_unsupported_op_types(
)
return high_precision_nodes, low_precision_nodes

def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str], list[str]]:
def _get_tensors_to_cast(
self, low_precision_nodes: list[str]
) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
cast_to_fp16 = [] # Tensors to cast down to FP16
cast_to_fp32 = [] # Tensors to cast up to FP32
# Keep track of the low precision nodes that take an FP32 input.
fp32_input_to_low_precision_node = {}

# Get tensors for FP16 nodes
for node in self.model.graph.node:
if node.name in low_precision_nodes:
# Cast inputs to FP16 nodes down to FP16
cast_to_fp16.extend(node.input)
for input in node.input:
if self._should_skip_low_precision_input_conversion(node, input):
cast_to_fp32.append(input)
fp32_input_to_low_precision_node[input] = node
else:
cast_to_fp16.append(input)

# Cast outputs from FP16 nodes up to FP32
cast_to_fp32.extend(node.output)

@@ -444,7 +467,7 @@ def _get_tensors_to_cast(self, low_precision_nodes: list[str]) -> tuple[list[str

logger.debug(f"tensors to cast to FP16: {cast_to_fp16}")
logger.debug(f"tensors to cast to FP32: {cast_to_fp32}")
return cast_to_fp16, cast_to_fp32
return cast_to_fp16, cast_to_fp32, fp32_input_to_low_precision_node

def _convert_initializers(
self, low_precision_nodes: list[str], high_precision_nodes: list[str]
@@ -557,6 +580,8 @@ def convert_initializer(
for node in self.model.graph.node:
if node.name in low_precision_nodes:
for init in self.node_to_init_map[node.name]:
if self._should_skip_low_precision_input_conversion(node, init.name):
continue
modified |= convert_initializer(
init,
node,
@@ -1069,3 +1094,28 @@ def _sanitize_model(self):
)
graph_sanitizer.sanitize()
self.model = graph_sanitizer.model

def _should_skip_low_precision_input_conversion(
self, node: onnx.NodeProto, input_name: str
) -> bool:
"""Check if the input should be skipped for low precision conversion.

This is used for nodes that have inputs that MUST remain in FP32.
"""
match self.low_precision_type.str_short:
case "fp16":
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
case "bf16":
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16
case _:
raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")

if node.op_type in skip_inputs_map:
# Find the index of the input among the node's inputs
inputs_lst = list(node.input)
if input_name not in inputs_lst:
raise ValueError(f"Input {input_name} not found in node {node.name}.")
input_index = inputs_lst.index(input_name)
# Check if we should skip this input for low precision conversion
return input_index in skip_inputs_map[node.op_type]
return False
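
Note (not part of the diff): the two skip maps added above encode which Resize inputs must keep their FP32 element type — the scales input at index 2 under fp16, and both roi (index 1) and scales under bf16. A minimal standalone sketch of that lookup, with hypothetical names (SKIP_FP16, SKIP_BF16, should_skip) that are not part of this PR:

import onnx
from onnx import helper

SKIP_FP16 = {"Resize": {2}}     # keep 'scales' (input index 2) in FP32
SKIP_BF16 = {"Resize": {1, 2}}  # keep 'roi' (1) and 'scales' (2) in FP32


def should_skip(node: onnx.NodeProto, input_name: str, low_precision: str) -> bool:
    """Return True if this input of the node must stay in FP32."""
    skip_map = SKIP_FP16 if low_precision == "fp16" else SKIP_BF16
    if node.op_type not in skip_map:
        return False
    inputs = list(node.input)
    return input_name in inputs and inputs.index(input_name) in skip_map[node.op_type]


resize = helper.make_node("Resize", ["x", "roi", "scales"], ["y"], name="resize")
assert should_skip(resize, "scales", "fp16")   # scales stays FP32 under fp16
assert not should_skip(resize, "x", "fp16")    # the data input is converted normally
assert should_skip(resize, "roi", "bf16")      # roi also stays FP32 under bf16
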
159 changes: 155 additions & 4 deletions tests/unit/onnx/autocast/test_precisionconverter.py
@@ -179,7 +179,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
)

# Test when relu node is in low precision
cast_down, cast_up = converter._get_tensors_to_cast(["relu"])
cast_down, cast_up, _ = converter._get_tensors_to_cast(["relu"])
assert "add_output" in cast_down # Input to relu should be cast down
assert "Y" in cast_up # Output of relu should be cast up
if not keep_io_types:
@@ -188,7 +188,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type):
) # Input to gemm should be cast up, because network inputs are converted to FP16

# Test when add node is in low precision
cast_down, cast_up = converter._get_tensors_to_cast(["add"])
cast_down, cast_up, _ = converter._get_tensors_to_cast(["add"])
assert "gemm_output" in cast_down # Input to add should be cast down
assert "add_init" not in cast_down # Initializer should not be in cast list
assert "add_output" in cast_up # Output of add should be cast up
@@ -314,13 +314,13 @@ def test_get_tensors_to_cast_multiple_consumers(
)

# Test when gemm2 and add1 nodes are in low precision
cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1"])
cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1"])
assert "X" in cast_down # Input to gemm2 should be cast down
assert "gemm2_output" in cast_up # Output of gemm2 should be cast up
assert "Y1" in cast_up # Output of add1 should be cast up

# Test when all nodes except gemm1 are in low precision
cast_down, cast_up = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
cast_down, cast_up, _ = converter._get_tensors_to_cast(["gemm2", "add1", "add2"])
assert "gemm1_output" in cast_down # Input to gemm2 should be cast down
assert "Y1" in cast_up # Output of add1 should be cast up
assert "Y2" in cast_up # Output of add2 should be cast up
@@ -1173,3 +1173,154 @@ def test_casted_input_to_output_model(
high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
)
onnx.checker.check_model(converted_model)


@pytest.fixture
def create_model_with_resize_op():
"""
Creates an ONNX model that contains a resize operation in the middle of the computation flow.

The model structure:
X -> Add -> Resize -> Relu -> Y
"""
# Create inputs and outputs
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])

# Create initializer for add operation
add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
add_init = numpy_helper.from_array(add_const, name="add_const")

# Create resize parameters
roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")
scales = numpy_helper.from_array(
np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32), name="scales"
)

# Create nodes: Add -> Resize -> Relu
add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
resize_node = helper.make_node(
"Resize", ["add_out", "roi", "scales"], ["resize_out"], name="resize", mode="nearest"
)
relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")

# Build the graph
graph = helper.make_graph(
[add_node, resize_node, relu_node],
"model_with_resize",
[x],
[y],
[add_init, roi_empty, scales],
)

model = helper.make_model(graph, producer_name="model_with_resize")
model.opset_import[0].version = 20
model.ir_version = 10
onnx.checker.check_model(model)

model = onnx_utils.infer_shapes(model)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

return model, value_info_map, initializer_map, node_to_init_map


@pytest.fixture
def create_model_with_resize_op_tensor_scales():
"""
Creates an ONNX model that contains a resize operation where the scales
are computed from a second network input through an Add operation.

The model structure:
X -> Add -> Resize -> Relu -> Y
scales_input -> Add -> scales_tensor /
"""
# Create inputs and outputs
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32])
scales_input = helper.make_tensor_value_info("scales_input", TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64])

# Create initializers
add_const = np.ones((1, 3, 32, 32), dtype=np.float32)
add_init = numpy_helper.from_array(add_const, name="add_const")

# Create scales computation initializer (add small offset to input scales)
scales_offset = np.array(
[0.0, 0.0, 1.0, 1.0], dtype=np.float32
) # Will result in [1,1,2,2] when added to [1,1,1,1] input
scales_offset_init = numpy_helper.from_array(scales_offset, name="scales_offset")

# Create resize parameters
roi_empty = numpy_helper.from_array(np.array([], dtype=np.float32), name="roi")

# Create nodes
add_node = helper.make_node("Add", ["X", "add_const"], ["add_out"], name="add")
scales_add_node = helper.make_node(
"Add", ["scales_input", "scales_offset"], ["scales_tensor"], name="scales_add"
)
resize_node = helper.make_node(
"Resize", ["add_out", "roi", "scales_tensor"], ["resize_out"], name="resize", mode="nearest"
)
relu_node = helper.make_node("Relu", ["resize_out"], ["Y"], name="relu")

# Build the graph
graph = helper.make_graph(
[add_node, scales_add_node, resize_node, relu_node],
"model_with_resize_tensor_scales",
[x, scales_input], # Two network inputs
[y],
[add_init, scales_offset_init, roi_empty],
)

model = helper.make_model(graph, producer_name="model_with_resize_tensor_scales")
model.opset_import[0].version = 20
model.ir_version = 10
onnx.checker.check_model(model)

model = onnx_utils.infer_shapes(model)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_resize_op_initializer_conversion(
create_model_with_resize_op, keep_io_types, low_precision_type
):
model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op

converter = PrecisionConverter(
model,
value_info_map,
initializer_map,
node_to_init_map,
keep_io_types=keep_io_types,
low_precision_type=low_precision_type,
)
converted_model = converter.convert(
high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
)
onnx.checker.check_model(converted_model)


@pytest.mark.parametrize("keep_io_types", [True, False])
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_resize_op_tensor_scales_conversion(
create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type
):
model, value_info_map, initializer_map, node_to_init_map = (
create_model_with_resize_op_tensor_scales
)

converter = PrecisionConverter(
model,
value_info_map,
initializer_map,
node_to_init_map,
keep_io_types=keep_io_types,
low_precision_type=low_precision_type,
)
converted_model = converter.convert(
high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
)
onnx.checker.check_model(converted_model)
Comment on lines +1306 to +1326
⚠️ Potential issue | 🟠 Major

Add assertions to verify tensor scales precision handling.

Similar to the initializer test, this test lacks assertions for the core functionality. With tensor-based scales, you should verify that the scales_tensor output feeding into Resize remains FP32, and that appropriate Cast nodes are inserted.

Apply this diff to add the missing assertions:

     converted_model = converter.convert(
         high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node]
     )
     onnx.checker.check_model(converted_model)
+
+    # Verify roi remains FP32
+    roi_init = next(init for init in converted_model.graph.initializer if init.name == "roi")
+    assert roi_init.data_type == TensorProto.FLOAT, "roi should remain FP32"
+
+    # Verify scales_tensor (input to Resize) has Cast to FP32
+    # Since scales_add is in low_precision_nodes but feeds Resize which needs FP32 scales,
+    # there should be a Cast node inserted
+    cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"]
+    scales_cast = [c for c in cast_nodes if c.input[0] == "scales_tensor"]
+    assert len(scales_cast) > 0, "Cast node should be inserted for scales_tensor to Resize"
+    assert scales_cast[0].attribute[0].i == TensorProto.FLOAT, "scales_tensor should be cast to FP32"
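
The same pattern applies to test_resize_op_initializer_conversion above, which also runs the conversion without checking the resulting tensor types. A hedged sketch of a helper one could add (the helper name and the assumption that the converted initializers keep their original names are illustrative, not part of the PR); by default it checks only 'scales', which the skip maps keep in FP32 for both fp16 and bf16, and 'roi' could be added for the bf16 case:

from onnx import TensorProto


def assert_resize_float_inputs_stay_fp32(converted_model, names=("scales",)):
    # Hypothetical check: Resize's pinned float inputs should not have been converted down.
    inits = {init.name: init for init in converted_model.graph.initializer}
    for name in names:
        assert inits[name].data_type == TensorProto.FLOAT, f"{name} should remain FP32"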