Skip to content

Commit b021fd0

Browse files
authored
Support im2row
Differential Revision: D83620790 Pull Request resolved: #14729
1 parent 7116e0a commit b021fd0

File tree

2 files changed

+406
-0
lines changed

2 files changed

+406
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,3 +1303,116 @@ def rope(
13031303
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
13041304
)
13051305
return rotated.view(original_shape)
1306+
1307+
1308+
@impl(m, "im2row")
def im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts an input tensor into a matrix where each row is a flattened sliding window (patch)
    from the input, suitable for use in convolution as a matrix multiplication (im2row).

    Args:
        - input_tensor: Input tensor of shape (N, C, H, W), or (N, H, W, C) if channel_last.
          A 3D input is also accepted; a height dimension of size 1 is inserted
          ((N, C, W) -> (N, C, 1, W), or (N, W, C) -> (N, 1, W, C) if channel_last).
        - kernel_size: Size of the convolution kernel as (kH, kW).
        - dilation: Dilation of the convolution kernel as (dH, dW).
        - padding: Padding to apply to the input as (pH, pW).
        - stride: Stride of the convolution as (sH, sW).
        - in_zero_point: int32 zero point for input quantization, used as the padding value.
          Must hold a single element or have shape (N,) (one value per batch entry).
          May be None, in which case any padding is filled with zeros by unfold.
        - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
        - Tensor of shape (N, num_patches, C * kH * kW) with the dtype of input_tensor.
          Patches are flattened in (C, kH, kW) order, because the input is converted
          to NCHW before unfolding, including when channel_last is True.
          NOTE(review): confirm that channel_last consumers expect this ordering.

    Raises:
        - ValueError: if in_zero_point has more than one element and is not of shape (N,),
          or if it is not an int32 tensor.
    """
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        if in_zero_point.numel() != 1 and in_zero_point.shape != (
            input_tensor.shape[0],
        ):
            raise ValueError(
                f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}"
            )
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    # Unpacking four names also enforces that the input is 4D at this point.
    N, C, _H, _W = input_tensor.shape
    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # unfold can only pad with zeros, so when a zero point is given the input
    # is padded manually with the zero point value of each batch entry.
    if in_zero_point is not None and (pH > 0 or pW > 0):
        # Flatten first so that any single-element zero point (0-dim, (1,),
        # (1, 1), ...) expands to one value per batch entry; expand() alone
        # rejects tensors with more dimensions than the requested size.
        zero_points = in_zero_point.reshape(-1).expand(N)

        input_tensor = torch.stack(
            [
                torch.nn.functional.pad(
                    sample,
                    (pW, pW, pH, pH),
                    mode="constant",
                    value=zero_point.item(),
                )
                for sample, zero_point in zip(input_tensor, zero_points)
            ]
        )

        padding = (0, 0)  # Already padded manually

    # Extract sliding local blocks.
    # torch.nn.functional.unfold maps (N, C, H, W) -> (N, C*kH*kW, L),
    # where L is the number of sliding windows.
    patches = torch.nn.functional.unfold(
        input_tensor.float(),  # unfold not implemented for int
        kernel_size=(kH, kW),
        dilation=(dH, dW),
        padding=padding,
        stride=(sH, sW),
    ).to(
        input_tensor.dtype
    )  # (N, C*kH*kW, L)

    # One row per patch: (N, C*kH*kW, L) -> (N, L, C*kH*kW)
    return patches.transpose(1, 2).contiguous()
1398+
1399+
1400+
@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Per-tensor variant of im2row that takes the input zero point as a Python int.

    The integer is wrapped in a 0-dim int32 tensor and all arguments are
    forwarded unchanged to im2row, whose result is returned as-is.
    """
    zero_point_tensor = torch.tensor(in_zero_point, dtype=torch.int32)
    return im2row(
        input_tensor=input_tensor,
        kernel_size=kernel_size,
        dilation=dilation,
        padding=padding,
        stride=stride,
        in_zero_point=zero_point_tensor,
        channel_last=channel_last,
    )

0 commit comments

Comments
 (0)