Skip to content

Commit 4372a14

Browse files
Fix const prop pass when a const prop tensor has zero stride, make it contiguous (#14725)
1 parent 0145604 commit 4372a14

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ def get_propagated_const_tensor_dict(
164164
with torch.no_grad():
165165
# Execute the `node.target` and create a new propagated constant tensor.
166166
prop_constant_tensor = node.target(*args_data, **kwargs_data)
167+
168+
# ExecuTorch doesn't support zero strides, so we need to ensure the tensor is contiguous
169+
# if it has any zero strides from broadcasting/expansion operations
170+
if (
171+
isinstance(prop_constant_tensor, torch.Tensor)
172+
and 0 in prop_constant_tensor.stride()
173+
):
174+
prop_constant_tensor = prop_constant_tensor.contiguous()
167175
const_node_to_tensor[node] = prop_constant_tensor
168176

169177
return const_node_to_tensor

exir/tests/test_passes.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,17 @@
2424
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
2525
QuantizationConfig,
2626
)
27-
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
27+
from executorch.backends.xnnpack.utils.configs import (
28+
get_xnnpack_executorch_backend_config,
29+
)
30+
31+
from executorch.exir import (
32+
EdgeCompileConfig,
33+
EdgeProgramManager,
34+
memory,
35+
to_edge,
36+
to_edge_transform_and_lower,
37+
)
2838
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
2939
from executorch.exir.dialects.edge._ops import EdgeOpOverload
3040
from executorch.exir.emit import emit_program
@@ -2022,3 +2032,64 @@ def forward(self, x):
20222032
pass_result = constant_prop_pass(edge.exported_program())
20232033
# 1 constant: a (= self.w @ self.cst)
20242034
self.assertEqual(1, len(pass_result.constants))
2035+
2036+
def test_constant_prop_pass_zero_stride_tensors(self) -> None:
2037+
"""
2038+
Test that constant propagation correctly handles tensors with zero strides
2039+
by converting them to contiguous tensors. Zero-stride tensors can be created
2040+
by operations like expand() and are not supported by ExecuTorch.
2041+
"""
2042+
2043+
class ZeroStrideModel(torch.nn.Module):
2044+
def __init__(self):
2045+
super().__init__()
2046+
self.const_param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
2047+
2048+
def forward(self, x):
2049+
unsqueezed = self.const_param.unsqueeze(
2050+
1
2051+
) # Shape: (3, 1), strides: (1, 1)
2052+
# expand creates zero-stride tensor
2053+
expanded = unsqueezed.expand(3, 5) # Shape: (3, 5), strides: (1, 0)
2054+
2055+
# Use the expanded tensor with the input to prevent elimination
2056+
result = x + expanded.sum()
2057+
return result
2058+
2059+
model = ZeroStrideModel()
2060+
x = torch.randn(3, 5)
2061+
exported = torch.export.export(model, (x,))
2062+
2063+
# Before constant prop: verify we have the parameter
2064+
self.assertIn("const_param", exported.state_dict)
2065+
2066+
const_prop_result = constant_prop_pass(exported)
2067+
lowered = to_edge_transform_and_lower(
2068+
const_prop_result,
2069+
partitioner=[XnnpackPartitioner()],
2070+
)
2071+
2072+
# Should go through
2073+
lowered.to_executorch(get_xnnpack_executorch_backend_config([SpecPropPass()]))
2074+
self.assertGreater(len(const_prop_result.constants), 0)
2075+
2076+
# Find the propagated constant tensor
2077+
prop_tensor = None
2078+
for constant_name, constant_tensor in const_prop_result.constants.items():
2079+
if constant_name.startswith("_prop_tensor_constant"):
2080+
prop_tensor = constant_tensor
2081+
break
2082+
2083+
# Verify the propagated tensor exists and has no zero strides
2084+
self.assertIsNotNone(prop_tensor)
2085+
self.assertNotIn(
2086+
0,
2087+
prop_tensor.stride(),
2088+
f"Propagated tensor still has zero stride: {prop_tensor.stride()}",
2089+
)
2090+
2091+
# Verify the tensor is contiguous
2092+
self.assertTrue(
2093+
prop_tensor.is_contiguous(),
2094+
f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
2095+
)

0 commit comments

Comments
 (0)