Commit 7be1cc2

Update base for Update on "Introduce public MergedDataMap"
Add a public merged data map. A Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned]
2 parents 7411ce1 + 7c7b729 commit 7be1cc2
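The MergedDataMap itself is introduced in other files of this stack. Purely as an illustrative, hypothetical Python sketch of the idea (not the actual ExecuTorch API), a merged view over several named data maps can resolve a key by consulting each underlying map in order:

# Hypothetical illustration only; names and structure do not mirror the
# actual ExecuTorch MergedDataMap implementation.
from typing import Mapping, Optional, Sequence


class MergedDataMapSketch:
    """Read-only view over several named data maps; the first map containing a key wins."""

    def __init__(self, maps: Sequence[Mapping[str, bytes]]) -> None:
        self._maps = list(maps)

    def get(self, key: str) -> Optional[bytes]:
        for data_map in self._maps:
            if key in data_map:
                return data_map[key]
        return None


# Resolve names across two backing maps through one merged view.
merged = MergedDataMapSketch([{"linear.weight": b"\x01"}, {"linear.bias": b"\x02"}])
assert merged.get("linear.bias") == b"\x02"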

File tree: 22 files changed (+840 −37 lines)

backends/cadence/aot/ref_implementations.py

Lines changed: 113 additions & 0 deletions
@@ -1303,3 +1303,116 @@ def rope(
        [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
    )
    return rotated.view(original_shape)


@impl(m, "im2row")
def im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch)
    from the input, suitable for use in convolution as a matrix multiplication (im2row).

    Args:
    - input_tensor: Input tensor of shape (N, C, H, W) or (N, H, W, C) if channel_last.
    - kernel_size: Size of the convolution kernel.
    - dilation: Dilation of the convolution kernel.
    - padding: Padding to apply to the input.
    - stride: Stride of the convolution.
    - in_zero_point: Zero point for input quantization (broadcastable to input).
    - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
    - Tensor of shape (N, num_patches, patch_size)
    """
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        if in_zero_point.numel() != 1 and in_zero_point.shape != (
            input_tensor.shape[0],
        ):
            raise ValueError(
                f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}"
            )
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    N, C, H, W = input_tensor.shape
    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # Handle padding with zero point values
    if in_zero_point is not None and (pH > 0 or pW > 0):
        # Expand zero point to one value per batch element
        in_zero_point = in_zero_point.expand(N)

        # Pad input with the per-batch zero point values
        input_tensor = torch.stack(
            [
                torch.nn.functional.pad(
                    input_tensor[i],
                    (pW, pW, pH, pH),
                    mode="constant",
                    value=in_zero_point[i].item(),
                )
                for i in range(len(input_tensor))
            ]
        )

        padding = (0, 0)  # Already padded manually

    # Use unfold to extract sliding local blocks.
    # torch.nn.functional.unfold returns (N, C*kH*kW, L), where L = number of sliding windows.
    patches = torch.nn.functional.unfold(
        input_tensor.float(),  # unfold not implemented for int
        kernel_size=(kH, kW),
        dilation=(dH, dW),
        padding=padding,
        stride=(sH, sW),
    ).to(
        input_tensor.dtype
    )  # (N, C*kH*kW, L)

    # Transpose to (N, L, C*kH*kW)
    patches = patches.transpose(1, 2).contiguous()

    # Reshape to (N, L, C*kH*kW) == (N, num_patches, patch_size)
    patches = patches.view(N, -1, C * kH * kW)

    # If channel_last, output should be in NHWC patch order (but im2row is always row-major)
    return patches


@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: int,
    channel_last: bool = False,
) -> torch.Tensor:
    return im2row(
        input_tensor,
        kernel_size,
        dilation,
        padding,
        stride,
        torch.tensor(in_zero_point, dtype=torch.int32),
        channel_last,
    )
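As a quick shape check (not part of the diff), a minimal usage sketch of the reference implementation above, assuming torch is imported and the im2row / im2row_per_tensor functions are in scope:

import torch

# A 1x3x5x5 int8 input with a 3x3 kernel, stride 1, dilation 1, padding 1:
# the padded 7x7 spatial grid yields 5*5 = 25 sliding windows, each of
# 3 * 3 * 3 = 27 elements.
x = torch.randint(-128, 127, (1, 3, 5, 5), dtype=torch.int8)
zero_point = torch.tensor(0, dtype=torch.int32)

patches = im2row(
    x,
    kernel_size=(3, 3),
    dilation=(1, 1),
    padding=(1, 1),
    stride=(1, 1),
    in_zero_point=zero_point,
    channel_last=False,
)
assert patches.shape == (1, 25, 27)  # (N, num_patches, C*kH*kW)

# im2row.per_tensor takes the zero point as a plain int and wraps it in a tensor.
patches_pt = im2row_per_tensor(x, (3, 3), (1, 1), (1, 1), (1, 1), 0)
assert torch.equal(patches, patches_pt)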
