Skip to content

Commit 041a65a

Browse files
committed
Update on "Introduce public MergedDataMap"
Add public merged data map. Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned]
2 parents e2592ae + 9739398 commit 041a65a

37 files changed

+740
-313
lines changed

.ci/scripts/build-qnn-sdk.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set_up_aot() {
3838
-DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM=ON \
3939
-DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM_RUNNER=ON \
4040
-DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
41+
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \
4142
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
4243
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
4344
-DPYTHON_EXECUTABLE=python3

.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,11 +970,16 @@ jobs:
970970
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build
971971
972972
# Test models serially
973-
models="mv2 mv3 edsr resnet18 resnet50 dl3"
973+
models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
974974
for model in $models; do
975975
python -m examples.vulkan.export --model_name=$model --test
976976
done
977977
978+
# For selected vision models, test with dynamic shapes
979+
models="mv2 resnet18 resnet50 ic3 densenet161"
980+
for model in $models; do
981+
python -m examples.vulkan.export --model_name=$model --test -d
982+
done
978983
979984
test-vulkan-operators-linux:
980985
name: test-vulkan-operators-linux

backends/arm/operators/op_conv2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ def define_node(
182182
acc_type = ts.DType.FP32
183183

184184
tosa_graph.addConst(
185-
[1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
185+
[1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
186186
)
187187
tosa_graph.addConst(
188188
[1],
189-
output.dtype,
189+
inputs[1].dtype,
190190
weight_zp,
191191
name=f"{conv2d_output_name}_weight_zp",
192192
)
@@ -269,7 +269,7 @@ def define_node(
269269

270270
# For quantized convolution, rescale the output value back to the same
271271
# integer value domain of the next op. Otherwise return float32 output.
272-
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
272+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
273273
# Get scale_factor from input, weight, and output.
274274
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
275275
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]

backends/arm/test/ops/test_linear.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
from typing import Tuple
1010

11-
import pytest
12-
1311
import torch
1412
from executorch.backends.arm.quantizer.arm_quantizer import (
1513
get_symmetric_a16w8_quantization_config,
@@ -313,12 +311,8 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
313311
pipeline.run()
314312

315313

316-
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
314+
@common.parametrize("test_data", test_data_all_16a8w)
317315
@common.XfailIfNoCorstone300
318-
@pytest.mark.xfail(
319-
reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
320-
strict=False,
321-
)
322316
def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
323317
"""Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
324318
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -347,12 +341,8 @@ def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
347341
pipeline.run()
348342

349343

350-
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
344+
@common.parametrize("test_data", test_data_all_16a8w)
351345
@common.XfailIfNoCorstone320
352-
@pytest.mark.xfail(
353-
reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
354-
strict=False,
355-
)
356346
def test_linear_16a8w_u85_INT16(test_data: torch.Tensor):
357347
"""Test linear operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
358348
test_data, out_features, has_bias, per_channel_quantization = test_data()

backends/cadence/aot/ref_implementations.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
14161416
torch.tensor(in_zero_point, dtype=torch.int32),
14171417
channel_last,
14181418
)
1419+
1420+
1421+
@impl(m, "transposed_im2row")
1422+
def transposed_im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    output_padding: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts input tensor patches into im2row format for transposed convolutions.

    Every input element (c, i, j) gets its own plane of size (H_out, W_out). The
    element's value is written at each output position reached by the kernel
    window anchored at (i * stride - padding); every other position holds the
    input zero point (or 0 when no zero point is given).

    Args:
        - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
          A 3D input is treated as a 1D convolution and gets a height dim of 1.
        - kernel_size: Size of the convolution kernel (height, width).
        - dilation: Dilation of the convolution kernel (height, width).
        - padding: Padding of the convolution (height, width).
        - stride: Stride of the convolution (height, width).
        - output_padding: Additional output padding for transposed convolution.
        - in_zero_point: int32 zero point for input quantization; either a
          single-element tensor or a tensor of shape (N,) (one value per batch
          entry). May be None, in which case the background is 0.
        - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
        - 3D tensor of shape (N, H_out * W_out, C * H_in * W_in) with the same
          dtype and device as the input tensor.

    Raises:
        - ValueError: if in_zero_point is given and is not an int32 tensor.
    """
    # Handle 1D convolution case by adding height dimension
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None and in_zero_point.dtype != torch.int32:
        raise ValueError("Input zero point must be an int32 tensor")

    # Move to NCHW for processing if needed
    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    N, C, H_in, W_in = input_tensor.shape
    device = input_tensor.device
    num_pixels = C * H_in * W_in
    inp_flat = input_tensor.reshape(N, num_pixels)

    # Output spatial size of a transposed convolution
    H_out = (
        (H_in - 1) * stride[0]
        - 2 * padding[0]
        + dilation[0] * (kernel_size[0] - 1)
        + output_padding[0]
        + 1
    )
    W_out = (
        (W_in - 1) * stride[1]
        - 2 * padding[1]
        + dilation[1] * (kernel_size[1] - 1)
        + output_padding[1]
        + 1
    )

    # Spatial coordinates (i, j) of every flattened input element (c, i, j)
    ch_idx = torch.arange(num_pixels, device=device)
    ij_idx = ch_idx % (H_in * W_in)
    i_idx = ij_idx // W_in
    j_idx = ij_idx % W_in

    # Flattened kernel offsets (kh, kw)
    kh_idx = torch.arange(kernel_size[0], device=device)
    kw_idx = torch.arange(kernel_size[1], device=device)
    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
    kh_grid = kh_grid.reshape(-1)
    kw_grid = kw_grid.reshape(-1)
    num_kernel = kernel_size[0] * kernel_size[1]

    # Pair every input element with every kernel offset
    n_kernel = num_pixels * num_kernel
    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
    i_idx_b = i_idx.repeat_interleave(num_kernel)
    j_idx_b = j_idx.repeat_interleave(num_kernel)
    kh_b = kh_grid.repeat(num_pixels)
    kw_b = kw_grid.repeat(num_pixels)

    # Output position = upsampled top-left corner + dilated kernel offset
    h_out = i_idx_b * stride[0] - padding[0] + kh_b * dilation[0]
    w_out = j_idx_b * stride[1] - padding[1] + kw_b * dilation[1]

    # Mask for valid output positions (padding can push positions out of range)
    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)

    # Replicate the per-element indices across the batch dimension
    n_idx = torch.arange(N, device=device).view(-1, 1).expand(N, n_kernel).reshape(-1)
    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
    valid_full = valid.expand(N, n_kernel).reshape(-1)

    # Gather input values for each (batch, element, kernel offset) triple
    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)

    # Allocate on the input's device so the index tensors above can address it
    patches = torch.zeros(
        (N, num_pixels, H_out, W_out), dtype=input_tensor.dtype, device=device
    )

    # If in_zero_point is provided, use it as the background value
    if in_zero_point is not None:
        if in_zero_point.numel() == 1:
            patches.fill_(in_zero_point.item())
        else:
            # One zero point per batch entry, broadcast over the other dims
            assert in_zero_point.shape == (N,)
            # Cast to the patch dtype: adding the int32 tensor directly would
            # promote integer patches to int32, changing the output dtype and
            # mismatching the dtype of the values scattered below.
            patches = patches + in_zero_point.view(N, 1, 1, 1).to(
                dtype=patches.dtype, device=device
            )

    # Scatter input values to output positions (only valid positions)
    patches[
        n_idx[valid_full],
        ch_idx_full[valid_full],
        h_out_full[valid_full],
        w_out_full[valid_full],
    ] = inp_vals[valid_full]

    # Flatten spatial dims and move them ahead of the per-element dim
    patches = patches.view(N, num_pixels, -1).transpose(1, 2).contiguous()
    return patches

0 commit comments

Comments
 (0)