Skip to content

Commit 9739398

Browse files
committed
Update base for Update on "Introduce public MergedDataMap"
Add public merged data map. Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned]
2 parents 7be1cc2 + 881915d commit 9739398

30 files changed

+690
-269
lines changed

.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,11 +970,16 @@ jobs:
970970
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build
971971
972972
# Test models serially
973-
models="mv2 mv3 edsr resnet18 resnet50 dl3"
973+
models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
974974
for model in $models; do
975975
python -m examples.vulkan.export --model_name=$model --test
976976
done
977977
978+
# For selected vision models, test with dynamic shapes
979+
models="mv2 resnet18 resnet50 ic3 densenet161"
980+
for model in $models; do
981+
python -m examples.vulkan.export --model_name=$model --test -d
982+
done
978983
979984
test-vulkan-operators-linux:
980985
name: test-vulkan-operators-linux

backends/arm/operators/op_conv2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ def define_node(
182182
acc_type = ts.DType.FP32
183183

184184
tosa_graph.addConst(
185-
[1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
185+
[1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
186186
)
187187
tosa_graph.addConst(
188188
[1],
189-
output.dtype,
189+
inputs[1].dtype,
190190
weight_zp,
191191
name=f"{conv2d_output_name}_weight_zp",
192192
)
@@ -269,7 +269,7 @@ def define_node(
269269

270270
# For quantized convolution, rescale the output value back to the same
271271
# integer value domain of the next op. Otherwise return float32 output.
272-
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
272+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
273273
# Get scale_factor from input, weight, and output.
274274
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
275275
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]

backends/arm/test/ops/test_linear.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
from typing import Tuple
1010

11-
import pytest
12-
1311
import torch
1412
from executorch.backends.arm.quantizer.arm_quantizer import (
1513
get_symmetric_a16w8_quantization_config,
@@ -313,12 +311,8 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
313311
pipeline.run()
314312

315313

316-
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
314+
@common.parametrize("test_data", test_data_all_16a8w)
317315
@common.XfailIfNoCorstone300
318-
@pytest.mark.xfail(
319-
reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
320-
strict=False,
321-
)
322316
def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
323317
"""Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
324318
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -347,12 +341,8 @@ def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
347341
pipeline.run()
348342

349343

350-
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
344+
@common.parametrize("test_data", test_data_all_16a8w)
351345
@common.XfailIfNoCorstone320
352-
@pytest.mark.xfail(
353-
reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
354-
strict=False,
355-
)
356346
def test_linear_16a8w_u85_INT16(test_data: torch.Tensor):
357347
"""Test linear operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
358348
test_data, out_features, has_bias, per_channel_quantization = test_data()

backends/cadence/aot/ref_implementations.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
14161416
torch.tensor(in_zero_point, dtype=torch.int32),
14171417
channel_last,
14181418
)
1419+
1420+
1421+
@impl(m, "transposed_im2row")
1422+
def transposed_im2row(
1423+
input_tensor: torch.Tensor,
1424+
kernel_size: tuple[int, int],
1425+
dilation: tuple[int, int],
1426+
padding: tuple[int, int],
1427+
stride: tuple[int, int],
1428+
output_padding: tuple[int, int],
1429+
in_zero_point: torch.Tensor,
1430+
channel_last: bool = False,
1431+
) -> torch.Tensor:
1432+
"""
1433+
Converts input tensor patches into im2row format for transposed convolutions.
1434+
This function extracts patches from input in a pattern suitable for transposed convolution.
1435+
1436+
Args:
1437+
- input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
1438+
- kernel_size: Size of the convolution kernel.
1439+
- dilation: Dilation of the convolution kernel.
1440+
- padding: Padding to apply to the input.
1441+
- stride: Stride of the convolution.
1442+
- output_padding: Additional output padding for transposed convolution.
1443+
- in_zero_point: Zero point for input quantization (broadcastable to input).
1444+
- channel_last: If True, input is in NHWC format, else NCHW.
1445+
1446+
Returns:
1447+
- 3D tensor of shape (N, output_h * output_w, kernel_h * kernel_w * in_c)
1448+
"""
1449+
# Handle 1D convolution case by adding height dimension
1450+
if len(input_tensor.shape) == 3:
1451+
height_dim = 1 if channel_last else 2
1452+
input_tensor = input_tensor.unsqueeze(height_dim)
1453+
1454+
if in_zero_point is not None:
1455+
if in_zero_point.dtype != torch.int32:
1456+
raise ValueError("Input zero point must be an int32 tensor")
1457+
1458+
# Move to NCHW for processing if needed
1459+
if channel_last:
1460+
input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW
1461+
1462+
N, C, H_in, W_in = input_tensor.shape
1463+
1464+
# Output: (N, C*H_in*W_in, H_out, W_out)
1465+
H_out = (
1466+
(H_in - 1) * stride[0]
1467+
+ kernel_size[0]
1468+
+ output_padding[0]
1469+
- 2 * padding[0]
1470+
+ dilation[0] * (kernel_size[0] - 1)
1471+
)
1472+
W_out = (
1473+
(W_in - 1) * stride[1]
1474+
+ kernel_size[1]
1475+
+ output_padding[1]
1476+
- 2 * padding[1]
1477+
+ dilation[1] * (kernel_size[1] - 1)
1478+
)
1479+
1480+
# For each input pixel, create a channel where the upsampled (transposed conv) patch is placed
1481+
# Output: (N, C*H_in*W_in, H_out, W_out)
1482+
inp_flat = input_tensor.reshape(N, C * H_in * W_in)
1483+
1484+
# Calculate output spatial size
1485+
H_out = (
1486+
(H_in - 1) * stride[0]
1487+
- 2 * padding[0]
1488+
+ dilation[0] * (kernel_size[0] - 1)
1489+
+ output_padding[0]
1490+
+ 1
1491+
)
1492+
W_out = (
1493+
(W_in - 1) * stride[1]
1494+
- 2 * padding[1]
1495+
+ dilation[1] * (kernel_size[1] - 1)
1496+
+ output_padding[1]
1497+
+ 1
1498+
)
1499+
1500+
# Compute the upsampled (top-left) position for each input pixel
1501+
h_idx = torch.arange(H_in, device=input_tensor.device)
1502+
w_idx = torch.arange(W_in, device=input_tensor.device)
1503+
grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
1504+
out_h_idx = grid_h * stride[0] - padding[0]
1505+
out_w_idx = grid_w * stride[1] - padding[1]
1506+
1507+
# Compute all input pixel positions (flattened)
1508+
ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
1509+
ij_idx = ch_idx % (H_in * W_in)
1510+
i_idx = ij_idx // W_in
1511+
j_idx = ij_idx % W_in
1512+
1513+
# For each input pixel, compute the output positions for the kernel window
1514+
kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
1515+
kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
1516+
kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
1517+
kh_grid = kh_grid.reshape(-1)
1518+
kw_grid = kw_grid.reshape(-1)
1519+
num_kernel = kernel_size[0] * kernel_size[1]
1520+
1521+
# Broadcast to all channels and kernel positions
1522+
ch_idx_b = ch_idx.repeat_interleave(num_kernel)
1523+
n_kernel = ch_idx.shape[0] * num_kernel
1524+
1525+
i_idx_b = i_idx.repeat_interleave(num_kernel)
1526+
j_idx_b = j_idx.repeat_interleave(num_kernel)
1527+
kh_b = kh_grid.repeat(ch_idx.shape[0])
1528+
kw_b = kw_grid.repeat(ch_idx.shape[0])
1529+
1530+
h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
1531+
w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
1532+
1533+
# Mask for valid output positions
1534+
valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
1535+
1536+
# Prepare indices for advanced indexing
1537+
n_idx = (
1538+
torch.arange(N, device=input_tensor.device)
1539+
.view(-1, 1)
1540+
.expand(N, n_kernel)
1541+
.reshape(-1)
1542+
)
1543+
ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
1544+
h_out_full = h_out.expand(N, n_kernel).reshape(-1)
1545+
w_out_full = w_out.expand(N, n_kernel).reshape(-1)
1546+
valid_full = valid.expand(N, n_kernel).reshape(-1)
1547+
1548+
# Gather input values for each channel
1549+
inp_vals = inp_flat[:, ch_idx_b].reshape(-1)
1550+
1551+
# Create output tensor
1552+
patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=input_tensor.dtype)
1553+
1554+
# If in_zero_point is provided, fill patches with it
1555+
if in_zero_point is not None:
1556+
if in_zero_point.numel() == 1:
1557+
patches.fill_(in_zero_point.item())
1558+
else:
1559+
# Broadcast in_zero_point to (N, C, H_in, W_in)
1560+
assert in_zero_point.shape == (N,)
1561+
in_zero_point = in_zero_point.view(N, 1, 1, 1)
1562+
patches = patches + in_zero_point
1563+
1564+
# Scatter input values to output positions (only valid positions)
1565+
patches[
1566+
n_idx[valid_full],
1567+
ch_idx_full[valid_full],
1568+
h_out_full[valid_full],
1569+
w_out_full[valid_full],
1570+
] = inp_vals[valid_full]
1571+
1572+
# Optionally, flatten to (N, num_patches, patch_size) if needed
1573+
patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
1574+
return patches

0 commit comments

Comments
 (0)