 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
     QuantizationConfig,
 )
-from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
+from executorch.backends.xnnpack.utils.configs import (
+    get_xnnpack_executorch_backend_config,
+)
+
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    memory,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.emit import emit_program
@@ -2022,3 +2032,64 @@ def forward(self, x):
         pass_result = constant_prop_pass(edge.exported_program())
         # 1 constant: a (= self.w @ self.cst)
         self.assertEqual(1, len(pass_result.constants))
+
+    def test_constant_prop_pass_zero_stride_tensors(self) -> None:
+        """
+        Test that constant propagation correctly handles tensors with zero strides
+        by converting them to contiguous tensors. Zero-stride tensors can be created
+        by operations like expand() and are not supported by ExecuTorch.
+        """
+
+        class ZeroStrideModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const_param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
+
+            def forward(self, x):
+                unsqueezed = self.const_param.unsqueeze(
+                    1
+                )  # Shape: (3, 1), strides: (1, 1)
+                # expand() creates a zero-stride view
+                expanded = unsqueezed.expand(3, 5)  # Shape: (3, 5), strides: (1, 0)
+
+                # Combine the expanded tensor with the input so it is not optimized away
+                result = x + expanded.sum()
+                return result
+
+        model = ZeroStrideModel()
+        x = torch.randn(3, 5)
+        exported = torch.export.export(model, (x,))
+
+        # Before constant prop: verify we have the parameter
+        self.assertIn("const_param", exported.state_dict)
+
+        const_prop_result = constant_prop_pass(exported)
+        lowered = to_edge_transform_and_lower(
+            const_prop_result,
+            partitioner=[XnnpackPartitioner()],
+        )
+
+        # Lowering to ExecuTorch should succeed with the propagated constant
+        lowered.to_executorch(get_xnnpack_executorch_backend_config([SpecPropPass()]))
+        self.assertGreater(len(const_prop_result.constants), 0)
+
+        # Find the propagated constant tensor
+        prop_tensor = None
+        for constant_name, constant_tensor in const_prop_result.constants.items():
+            if constant_name.startswith("_prop_tensor_constant"):
+                prop_tensor = constant_tensor
+                break
+
+        # Verify the propagated tensor exists and has no zero strides
+        self.assertIsNotNone(prop_tensor)
+        self.assertNotIn(
+            0,
+            prop_tensor.stride(),
+            f"Propagated tensor still has zero stride: {prop_tensor.stride()}",
+        )
+
+        # Verify the tensor is contiguous
+        self.assertTrue(
+            prop_tensor.is_contiguous(),
+            f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
+        )
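
For reference, a minimal standalone sketch (not part of the diff) of the stride behavior this test exercises: expand() returns a view whose broadcast dimension has stride 0, and contiguous() materializes it with ordinary strides, which is what the test expects the propagated constant to look like.

import torch

# Analogous to const_param in the test above (hypothetical example values).
base = torch.tensor([1.0, 2.0, 3.0])

expanded = base.unsqueeze(1).expand(3, 5)
print(expanded.shape, expanded.stride())   # torch.Size([3, 5]) (1, 0)

materialized = expanded.contiguous()
print(materialized.stride())               # (5, 1) -- ordinary row-major strides
print(materialized.is_contiguous())        # True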