|
24 | 24 | from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
|
25 | 25 | QuantizationConfig,
|
26 | 26 | )
|
27 |
| -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge |
| 27 | +from executorch.backends.xnnpack.utils.configs import ( |
| 28 | + get_xnnpack_executorch_backend_config, |
| 29 | +) |
| 30 | + |
| 31 | +from executorch.exir import ( |
| 32 | + EdgeCompileConfig, |
| 33 | + EdgeProgramManager, |
| 34 | + memory, |
| 35 | + to_edge, |
| 36 | + to_edge_transform_and_lower, |
| 37 | +) |
28 | 38 | from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
|
29 | 39 | from executorch.exir.dialects.edge._ops import EdgeOpOverload
|
30 | 40 | from executorch.exir.emit import emit_program
|
@@ -2022,3 +2032,64 @@ def forward(self, x):
|
2022 | 2032 | pass_result = constant_prop_pass(edge.exported_program())
|
2023 | 2033 | # 1 constant: a (= self.w @ self.cst)
|
2024 | 2034 | self.assertEqual(1, len(pass_result.constants))
|
| 2035 | + |
| 2036 | + def test_constant_prop_pass_zero_stride_tensors(self) -> None: |
| 2037 | + """ |
| 2038 | + Test that constant propagation correctly handles tensors with zero strides |
| 2039 | + by converting them to contiguous tensors. Zero-stride tensors can be created |
| 2040 | + by operations like expand() and are not supported by ExecuTorch. |
| 2041 | + """ |
| 2042 | + |
| 2043 | + class ZeroStrideModel(torch.nn.Module): |
| 2044 | + def __init__(self): |
| 2045 | + super().__init__() |
| 2046 | + self.const_param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0])) |
| 2047 | + |
| 2048 | + def forward(self, x): |
| 2049 | + unsqueezed = self.const_param.unsqueeze( |
| 2050 | + 1 |
| 2051 | + ) # Shape: (3, 1), strides: (1, 1) |
| 2052 | + # expand creates zero-stride tensor |
| 2053 | + expanded = unsqueezed.expand(3, 5) # Shape: (3, 5), strides: (1, 0) |
| 2054 | + |
| 2055 | + # Use the expanded tensor with the input to prevent elimination |
| 2056 | + result = x + expanded.sum() |
| 2057 | + return result |
| 2058 | + |
| 2059 | + model = ZeroStrideModel() |
| 2060 | + x = torch.randn(3, 5) |
| 2061 | + exported = torch.export.export(model, (x,)) |
| 2062 | + |
| 2063 | + # Before constant prop: verify we have the parameter |
| 2064 | + self.assertIn("const_param", exported.state_dict) |
| 2065 | + |
| 2066 | + const_prop_result = constant_prop_pass(exported) |
| 2067 | + lowered = to_edge_transform_and_lower( |
| 2068 | + const_prop_result, |
| 2069 | + partitioner=[XnnpackPartitioner()], |
| 2070 | + ) |
| 2071 | + |
| 2072 | + # Should go through |
| 2073 | + lowered.to_executorch(get_xnnpack_executorch_backend_config([SpecPropPass()])) |
| 2074 | + self.assertGreater(len(const_prop_result.constants), 0) |
| 2075 | + |
| 2076 | + # Find the propagated constant tensor |
| 2077 | + prop_tensor = None |
| 2078 | + for constant_name, constant_tensor in const_prop_result.constants.items(): |
| 2079 | + if constant_name.startswith("_prop_tensor_constant"): |
| 2080 | + prop_tensor = constant_tensor |
| 2081 | + break |
| 2082 | + |
| 2083 | + # Verify the propagated tensor exists and has no zero strides |
| 2084 | + self.assertIsNotNone(prop_tensor) |
| 2085 | + self.assertNotIn( |
| 2086 | + 0, |
| 2087 | + prop_tensor.stride(), |
| 2088 | + f"Propagated tensor still has zero stride: {prop_tensor.stride()}", |
| 2089 | + ) |
| 2090 | + |
| 2091 | + # Verify the tensor is contiguous |
| 2092 | + self.assertTrue( |
| 2093 | + prop_tensor.is_contiguous(), |
| 2094 | + f"Propagated tensor is not contiguous: {prop_tensor.stride()}", |
| 2095 | + ) |
0 commit comments