Commit a5f3055

Andrew Grebenisan authored and facebook-github-bot committed
Add transposed im2row (pytorch#14738)
Summary: Continued support for custom cadence ops.

Reviewed By: hsharma35, eigen-k

Differential Revision: D83709868
1 parent b021fd0 commit a5f3055

File tree: 2 files changed, +326 -0 lines changed


backends/cadence/aot/ref_implementations.py

Lines changed: 156 additions & 0 deletions
@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
        torch.tensor(in_zero_point, dtype=torch.int32),
        channel_last,
    )


@impl(m, "transposed_im2row")
def transposed_im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    output_padding: tuple[int, int],
    in_zero_point: torch.Tensor | None,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts input tensor patches into im2row format for transposed convolutions.

    This function extracts patches from the input in a pattern suitable for
    transposed convolution.

    Args:
    - input_tensor: Input spatial tensor in NCHW or NHWC format (3D or 4D).
    - kernel_size: Size of the convolution kernel.
    - dilation: Dilation of the convolution kernel.
    - padding: Padding to apply to the input.
    - stride: Stride of the convolution.
    - output_padding: Additional output padding for transposed convolution.
    - in_zero_point: Zero point for input quantization (broadcastable to input).
    - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
    - 3D tensor of shape (N, output_h * output_w, in_c * in_h * in_w), where
      each row holds the contributions of every input pixel to that output
      position.
    """
    # Handle the 1D convolution case by adding a height dimension.
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    # Move to NCHW for processing if needed.
    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    N, C, H_in, W_in = input_tensor.shape

    # For each input pixel, create a channel where the upsampled
    # (transposed conv) patch is placed.
    # Intermediate layout: (N, C*H_in*W_in, H_out, W_out).
    inp_flat = input_tensor.reshape(N, C * H_in * W_in)

    # Calculate the output spatial size using the standard transposed-conv
    # formula: (in - 1)*stride - 2*padding + dilation*(kernel - 1)
    # + output_padding + 1.
    H_out = (
        (H_in - 1) * stride[0]
        - 2 * padding[0]
        + dilation[0] * (kernel_size[0] - 1)
        + output_padding[0]
        + 1
    )
    W_out = (
        (W_in - 1) * stride[1]
        - 2 * padding[1]
        + dilation[1] * (kernel_size[1] - 1)
        + output_padding[1]
        + 1
    )

    # Compute the upsampled (top-left) position for each input pixel.
    h_idx = torch.arange(H_in, device=input_tensor.device)
    w_idx = torch.arange(W_in, device=input_tensor.device)
    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
    out_h_idx = grid_h * stride[0] - padding[0]
    out_w_idx = grid_w * stride[1] - padding[1]

    # Compute all input pixel positions (flattened).
    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
    ij_idx = ch_idx % (H_in * W_in)
    i_idx = ij_idx // W_in
    j_idx = ij_idx % W_in

    # For each input pixel, compute the output positions for the kernel window.
    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
    kh_grid = kh_grid.reshape(-1)
    kw_grid = kw_grid.reshape(-1)
    num_kernel = kernel_size[0] * kernel_size[1]

    # Broadcast to all channels and kernel positions.
    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
    n_kernel = ch_idx.shape[0] * num_kernel

    i_idx_b = i_idx.repeat_interleave(num_kernel)
    j_idx_b = j_idx.repeat_interleave(num_kernel)
    kh_b = kh_grid.repeat(ch_idx.shape[0])
    kw_b = kw_grid.repeat(ch_idx.shape[0])

    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]

    # Mask for valid output positions.
    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)

    # Prepare indices for advanced indexing.
    n_idx = (
        torch.arange(N, device=input_tensor.device)
        .view(-1, 1)
        .expand(N, n_kernel)
        .reshape(-1)
    )
    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
    valid_full = valid.expand(N, n_kernel).reshape(-1)

    # Gather input values for each channel.
    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)

    # Create the output tensor.
    patches = torch.zeros(
        (N, C * H_in * W_in, H_out, W_out),
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )

    # If in_zero_point is provided, fill patches with it.
    if in_zero_point is not None:
        if in_zero_point.numel() == 1:
            patches.fill_(in_zero_point.item())
        else:
            # Broadcast in_zero_point to (N, C, H_in, W_in).
            assert in_zero_point.shape == (N,)
            in_zero_point = in_zero_point.view(N, 1, 1, 1)
            patches = patches + in_zero_point

    # Scatter input values to the output positions (valid positions only).
    patches[
        n_idx[valid_full],
        ch_idx_full[valid_full],
        h_out_full[valid_full],
        w_out_full[valid_full],
    ] = inp_vals[valid_full]

    # Flatten to (N, num_patches, patch_size).
    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
    return patches
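
As a sanity check, here is a minimal usage sketch mirroring the first test case below (assuming the op is registered and callable as torch.ops.cadence.transposed_im2row, which is how the tests exercise it):

    import torch

    x = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32)  # (N=1, C=1, 2, 2)
    patches = torch.ops.cadence.transposed_im2row(
        x,
        (2, 2),  # kernel_size
        (1, 1),  # dilation
        (0, 0),  # padding
        (1, 1),  # stride
        (0, 0),  # output_padding
        None,    # in_zero_point
        False,   # channel_last
    )
    # H_out = W_out = (2 - 1)*1 - 0 + 1*(2 - 1) + 0 + 1 = 3, so there are
    # 9 output positions, each holding a 4-element (C*H_in*W_in) patch row.
    assert patches.shape == (1, 9, 4)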

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 170 additions & 0 deletions
@@ -2136,3 +2136,173 @@ def test_im2row(
            torch.equal(output, expected_output),
            f"im2row output mismatch in {name}: got {output}, expected {expected_output}",
        )

    @expand(
        [
            (
                "basic_2x2",
                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
                (2, 2),  # kernel_size
                (1, 1),  # dilation
                (0, 0),  # padding
                (1, 1),  # stride
                (0, 0),  # output_padding
                None,  # in_zero_point
                False,  # channel_last
                torch.tensor(
                    [
                        [
                            [1, 0, 0, 0],
                            [1, 2, 0, 0],
                            [0, 2, 0, 0],
                            [1, 0, 3, 0],
                            [1, 2, 3, 4],
                            [0, 2, 0, 4],
                            [0, 0, 3, 0],
                            [0, 0, 3, 4],
                            [0, 0, 0, 4],
                        ]
                    ],
                    dtype=torch.int32,
                ),
            ),
            (
                "basic_2x2_with_zero_point",
                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
                (2, 2),  # kernel_size
                (1, 1),  # dilation
                (0, 0),  # padding
                (1, 1),  # stride
                (0, 0),  # output_padding
                torch.tensor(100, dtype=torch.int32),  # in_zero_point
                False,  # channel_last
                torch.tensor(
                    [
                        [
                            [1, 100, 100, 100],
                            [1, 2, 100, 100],
                            [100, 2, 100, 100],
                            [1, 100, 3, 100],
                            [1, 2, 3, 4],
                            [100, 2, 100, 4],
                            [100, 100, 3, 100],
                            [100, 100, 3, 4],
                            [100, 100, 100, 4],
                        ]
                    ],
                    dtype=torch.int32,
                ),
            ),
            (
                "basic_2x2_with_stride_2",
                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
                (2, 2),  # kernel_size
                (1, 1),  # dilation
                (0, 0),  # padding
                (2, 2),  # stride
                (0, 0),  # output_padding
                None,  # in_zero_point
                False,  # channel_last
                torch.tensor(
                    [
                        [
                            [1, 0, 0, 0],
                            [1, 0, 0, 0],
                            [0, 2, 0, 0],
                            [0, 2, 0, 0],
                            [1, 0, 0, 0],
                            [1, 0, 0, 0],
                            [0, 2, 0, 0],
                            [0, 2, 0, 0],
                            [0, 0, 3, 0],
                            [0, 0, 3, 0],
                            [0, 0, 0, 4],
                            [0, 0, 0, 4],
                            [0, 0, 3, 0],
                            [0, 0, 3, 0],
                            [0, 0, 0, 4],
                            [0, 0, 0, 4],
                        ]
                    ],
                    dtype=torch.int32,
                ),
            ),
            (
                "batch2_with_batch2_zero_point",
                torch.tensor(
                    [
                        [[[1, 2], [3, 4]]],
                        [[[5, 6], [7, 8]]],
                    ],
                    dtype=torch.int32,
                ),  # input: (2, 1, 2, 2)
                (2, 2),  # kernel_size
                (1, 1),  # dilation
                (0, 0),  # padding
                (1, 1),  # stride
                (0, 0),  # output_padding
                torch.tensor([100, 200], dtype=torch.int32),  # in_zero_point per batch
                False,  # channel_last
                torch.tensor(
                    [
                        [
                            [1, 100, 100, 100],
                            [1, 2, 100, 100],
                            [100, 2, 100, 100],
                            [1, 100, 3, 100],
                            [1, 2, 3, 4],
                            [100, 2, 100, 4],
                            [100, 100, 3, 100],
                            [100, 100, 3, 4],
                            [100, 100, 100, 4],
                        ],
                        [
                            [5, 200, 200, 200],
                            [5, 6, 200, 200],
                            [200, 6, 200, 200],
                            [5, 200, 7, 200],
                            [5, 6, 7, 8],
                            [200, 6, 200, 8],
                            [200, 200, 7, 200],
                            [200, 200, 7, 8],
                            [200, 200, 200, 8],
                        ],
                    ],
                    dtype=torch.int32,
                ),
            ),
        ]
    )
    def test_transposed_im2row(
        self,
        name: str,
        input_tensor: torch.Tensor,
        kernel_size: tuple[int, int],
        dilation: tuple[int, int],
        padding: tuple[int, int],
        stride: tuple[int, int],
        output_padding: tuple[int, int],
        in_zero_point: torch.Tensor | int | None,
        channel_last: bool,
        expected_output: torch.Tensor,
    ) -> None:
        output = torch.ops.cadence.transposed_im2row(
            input_tensor,
            kernel_size,
            dilation,
            padding,
            stride,
            output_padding,
            in_zero_point,
            channel_last,
        )

        self.assertEqual(
            output.shape,
            expected_output.shape,
            f"transposed_im2row output shape mismatch in {name}: got {output.shape}, expected {expected_output.shape}",
        )
        self.assertTrue(
            torch.equal(output, expected_output),
            f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
        )
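
The expected tensors above can also be cross-checked against a straightforward loop formulation of the same scatter. The helper below is hypothetical, written here only for illustration (it is not part of the commit); it restates the vectorized reference implementation pixel by pixel, with zero-point handling omitted for brevity:

    import torch

    def naive_transposed_im2row(
        x: torch.Tensor,  # (N, C, H_in, W_in)
        kernel_size: tuple[int, int],
        dilation: tuple[int, int],
        padding: tuple[int, int],
        stride: tuple[int, int],
        output_padding: tuple[int, int],
    ) -> torch.Tensor:
        N, C, H_in, W_in = x.shape
        # Standard transposed-conv output size.
        H_out = (H_in - 1) * stride[0] - 2 * padding[0] \
            + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
        W_out = (W_in - 1) * stride[1] - 2 * padding[1] \
            + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
        patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=x.dtype)
        for n in range(N):
            for c in range(C):
                for i in range(H_in):
                    for j in range(W_in):
                        # One "channel" per input pixel, holding its scattered patch.
                        ch = c * H_in * W_in + i * W_in + j
                        for kh in range(kernel_size[0]):
                            for kw in range(kernel_size[1]):
                                h = i * stride[0] - padding[0] + kh * dilation[0]
                                w = j * stride[1] - padding[1] + kw * dilation[1]
                                if 0 <= h < H_out and 0 <= w < W_out:
                                    patches[n, ch, h, w] = x[n, c, i, j]
        return patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()

For example, naive_transposed_im2row(torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32), (2, 2), (1, 1), (0, 0), (1, 1), (0, 0)) reproduces the basic_2x2 expected tensor above.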
