@@ -1303,3 +1303,116 @@ def rope(
13031303 [x0 * cos_tensor - x1 * sin_tensor , x0 * sin_tensor + x1 * cos_tensor ], dim = - 1
13041304 )
13051305 return rotated .view (original_shape )
1306+
1307+
@impl(m, "im2row")
def im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts an input tensor into a batch of 2D matrices where each row is a flattened
    sliding window (patch) of the input, so that a convolution can be expressed as a
    matrix multiplication (im2row).

    A 3D input is treated as a 1D-convolution input: a unit height dimension is inserted
    before processing. Channel-last inputs are converted to NCHW first, so every output
    row is laid out channel-major (C, kH, kW) regardless of `channel_last`.

    Args:
        - input_tensor: Input tensor of shape (N, C, H, W), or (N, H, W, C) if channel_last.
          3D inputs (without the height dimension) are also accepted.
        - kernel_size: (kH, kW) size of the convolution kernel.
        - dilation: (dH, dW) dilation of the convolution kernel.
        - padding: (pH, pW) padding applied symmetrically to height and width.
        - stride: (sH, sW) stride of the convolution.
        - in_zero_point: int32 tensor holding either a single zero point (any shape with
          one element) or one zero point per batch entry (shape (N,)). Padded positions
          are filled with this value. If None, padded positions are filled with zeros.
        - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
        - Contiguous tensor of shape (N, num_patches, C * kH * kW) with the same dtype as
          the input.

    Raises:
        - ValueError: If in_zero_point is neither single-element nor of shape (N,), or if
          it is not an int32 tensor.
    """
    if input_tensor.dim() == 3:
        # 1D-convolution input: insert a unit height axis so the 2D path applies.
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        # The batch axis is dim 0 in both layouts, so this check is layout-independent.
        if in_zero_point.numel() != 1 and in_zero_point.shape != (
            input_tensor.shape[0],
        ):
            raise ValueError(
                f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}"
            )
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    # Only N and C are needed below; the 4-way unpack doubles as a rank check.
    N, C, _, _ = input_tensor.shape
    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # unfold can only pad with zeros, so pad manually when a zero point is supplied.
    if in_zero_point is not None and (pH > 0 or pW > 0):
        pad_widths = (pW, pW, pH, pH)  # (left, right, top, bottom)
        # Flatten first so that single-element tensors of any rank (0-d, (1,), (1, 1), ...)
        # are handled uniformly.
        zero_points = in_zero_point.reshape(-1)

        if zero_points.numel() == 1:
            # One zero point for the whole batch: pad everything in a single call.
            input_tensor = torch.nn.functional.pad(
                input_tensor,
                pad_widths,
                mode="constant",
                value=zero_points.item(),
            )
        else:
            # Per-batch zero points: each sample needs its own fill value.
            input_tensor = torch.stack(
                [
                    torch.nn.functional.pad(
                        sample,
                        pad_widths,
                        mode="constant",
                        value=zero_point,
                    )
                    for sample, zero_point in zip(input_tensor, zero_points.tolist())
                ]
            )

        padding = (0, 0)  # Already padded manually

    # Use unfold to extract sliding local blocks.
    # torch.nn.functional.unfold maps (N, C, H, W) -> (N, C*kH*kW, L), where L is the
    # number of sliding windows.
    # NOTE: the round trip through float32 is exact only for values that float32 can
    # represent exactly (e.g. all int8/uint8/int16 values).
    patches = torch.nn.functional.unfold(
        input_tensor.float(),  # unfold not implemented for int
        kernel_size=(kH, kW),
        dilation=(dH, dW),
        padding=padding,
        stride=(sH, sW),
    ).to(
        input_tensor.dtype
    )  # (N, C*kH*kW, L)

    # Transpose to (N, L, C*kH*kW) so that each row is one flattened patch.
    return patches.transpose(1, 2).contiguous()
1398+
1399+
@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Per-tensor variant of im2row that takes the input zero point as a plain Python int.

    The integer is wrapped in a 0-dimensional int32 tensor and every argument is then
    forwarded unchanged to im2row.

    Args:
        - input_tensor: Input tensor, forwarded to im2row.
        - kernel_size: Size of the convolution kernel.
        - dilation: Dilation of the convolution kernel.
        - padding: Padding to apply to the input.
        - stride: Stride of the convolution.
        - in_zero_point: Zero point for input quantization, shared by the whole tensor.
        - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
        - The tensor returned by im2row for the same arguments.
    """
    # im2row rejects zero points that are not int32 tensors, so build one explicitly.
    zero_point_tensor = torch.tensor(in_zero_point, dtype=torch.int32)
    return im2row(
        input_tensor,
        kernel_size,
        dilation,
        padding,
        stride,
        zero_point_tensor,
        channel_last,
    )
0 commit comments