Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def _is_supported_on_target(
if dim == 0:
return False

# If all input shapes are equal, the neutron is able to pad the last dimension of inputs and outputs.
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
if input_shapes.count(input_shapes[0]) == len(input_shapes):
if dim == len(input_shapes[0]) - 1:
return True

# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,46 @@ def test_cat__force_delegate():
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
)
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)


def test_cat__same_shapes_converter_padding_last_dimension():
target = "imxrt700"

# The Converter is capable of padding the last dimension of `cat` with the same input shapes.
input_shape = (3, 1, 3)

quantized_program = to_quantized_edge_program(
CatModule(2),
[input_shape, input_shape],
target=target,
neutron_converter_flavor="SDK_25_09",
custom_delegation_options=CustomDelegationOptions(),
).exported_program()

# Make sure the `Cat` was delegated.
assert not graph_contains_any_of_ops(
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
)
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)


def test_cat__same_shapes_converter_padding_middle_dimension():
target = "imxrt700"

# The Converter is not capable of padding the middle dimensions of `cat` with the same input shapes.
input_shape = (3, 1, 3)

quantized_program = to_quantized_edge_program(
CatModule(1),
[input_shape, input_shape],
target=target,
custom_delegation_options=CustomDelegationOptions(),
).exported_program()

# Make sure the `Cat` was NOT delegated.
assert graph_contains_any_of_ops(
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
)
assert not any(
"lowered_module" in node.name for node in quantized_program.graph.nodes
)
Loading