Skip to content

Commit c2be5db

Browse files
author
Github Executorch
committed
Title:
[cortex_m] Fix linear weight layout: transpose in AOT pass, align meta/ref impl. Summary: The linear path in ConvertToCortexMPass was not transposing weights (unlike conv2d), causing an inconsistency with the C++ runtime, which expects weights in [in_features, out_features] format per CMSIS-NN. Changes: - convert_to_cortex_m_pass.py: Transpose linear weights [out, in] -> [in, out]. - operators.py: Update the meta function to use weights.shape[1] for the output dimension. - operators.py: Remove .T from the reference implementation (weights are pre-transposed by the pass). Fixes the MV2 output shape mismatch: [1, 1280] -> [1, 1000] (verified with MV2 on Corstone-300/E8 with CMSIS-NN kernels). This fix ensures the AOT-compiled .pte file has correctly shaped output tensors for any model using quantized_linear (MV2, ResNet, MV3, etc.).
1 parent 5b3d9fc commit c2be5db

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

backends/cortex_m/ops/operators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def quantized_linear_meta(
352352
activation_min,
353353
) -> torch.Tensor:
354354

355-
shape = (*input.shape[:-1], weights.shape[0])
355+
shape = (*input.shape[:-1], weights.shape[1])
356356
return torch.empty(shape, dtype=input.dtype, device=input.device)
357357

358358

@@ -386,7 +386,7 @@ def quantized_linear_impl(
386386
input_reshaped = input_int32.reshape(new_shape)
387387

388388
lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset
389-
output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum
389+
output = torch.mm(input_reshaped, weights_int32) + lhs_sum + kernel_sum
390390
output_shape = (*input.shape[:-1], output.shape[-1])
391391
output_reshaped = output.reshape(output_shape)
392392
else:
@@ -396,7 +396,7 @@ def quantized_linear_impl(
396396
new_shape = (prod(input.shape[:-1]), input.shape[-1])
397397
input_reshaped = input_int32.reshape(new_shape)
398398

399-
output = torch.mm(input_reshaped, weights_int32.T)
399+
output = torch.mm(input_reshaped, weights_int32)
400400
if bias is not None:
401401
output = output + bias
402402
output_shape = (*input.shape[:-1], output.shape[-1])

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def _get_linear_replacement(self, node):
113113
kernel_sum_tensor = self._compute_kernel_sum(
114114
weights_tensor, bias_tensor, -input_zp, -weight_zp
115115
)
116+
117+
# Transpose weights from PyTorch format [out_features, in_features]
118+
# to CMSIS-NN format [in_features, out_features]
119+
weights_transposed = weights_tensor.T.contiguous()
120+
116121
with node.graph.inserting_after(weights):
117122
kernel_sum = create_constant_placeholder(
118123
self.exported_program,
@@ -122,9 +127,17 @@ def _get_linear_replacement(self, node):
122127
kernel_sum_tensor,
123128
)
124129

130+
weights_transposed_node = create_constant_placeholder(
131+
self.exported_program,
132+
node.graph,
133+
node.name + "_weights_transposed",
134+
InputKind.PARAMETER,
135+
weights_transposed,
136+
)
137+
125138
args = (
126139
node.args[0],
127-
weights,
140+
weights_transposed_node,
128141
None,
129142
kernel_sum,
130143
-input_zp,

0 commit comments

Comments (0)