@@ -1303,3 +1303,116 @@ def rope(
1303
1303
[x0 * cos_tensor - x1 * sin_tensor , x0 * sin_tensor + x1 * cos_tensor ], dim = - 1
1304
1304
)
1305
1305
return rotated .view (original_shape )
1306
+
1307
+
1308
@impl(m, "im2row")
def im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch)
    from the input, suitable for use in convolution as a matrix multiplication (im2row).

    Args:
    - input_tensor: Input tensor of shape (N, C, H, W) or (N, H, W, C) if channel_last.
      A 3D input is treated as missing its height dimension and is unsqueezed to 4D.
    - kernel_size: Size of the convolution kernel.
    - dilation: Dilation of the convolution kernel.
    - padding: Padding to apply to the input.
    - stride: Stride of the convolution.
    - in_zero_point: Zero point for input quantization; an int32 tensor that is either
      a scalar or per-batch of shape (N,). May be None for non-quantized input.
    - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
    - Tensor of shape (N, num_patches, patch_size), where patch_size == C * kH * kW.

    Raises:
    - ValueError: if in_zero_point is neither a scalar nor shape (N,), or is not int32.
    """
    # A 3D input is assumed to lack the height dimension; insert a size-1 height
    # axis so the rest of the function operates on a canonical 4D layout.
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        # Only a scalar zero point or one zero point per batch entry is supported.
        if in_zero_point.numel() != 1 and in_zero_point.shape != (
            input_tensor.shape[0],
        ):
            raise ValueError(
                f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}"
            )
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    N, C, H, W = input_tensor.shape
    kH, kW = kernel_size
    dH, dW = dilation
    pH, pW = padding
    sH, sW = stride

    # Handle padding with zero point values: quantized inputs must be padded with
    # the zero point (the quantized encoding of real 0), not with literal 0.
    if in_zero_point is not None and (pH > 0 or pW > 0):
        # Broadcast a scalar zero point to one value per batch entry, shape (N,).
        in_zero_point = in_zero_point.expand(N)

        # Pad each batch entry separately with its own zero-point value.
        input_tensor = torch.stack(
            [
                torch.nn.functional.pad(
                    input_tensor[i],
                    (pW, pW, pH, pH),
                    mode="constant",
                    value=in_zero_point[i].item(),
                )
                for i in range(len(input_tensor))
            ]
        )

        padding = (0, 0)  # Already padded manually

    # Use unfold to extract sliding local blocks.
    # torch.nn.functional.unfold returns (N, C*kH*kW, L), where L is the number
    # of sliding windows.
    patches = torch.nn.functional.unfold(
        input_tensor.float(),  # unfold not implemented for int dtypes
        kernel_size=(kH, kW),
        dilation=(dH, dW),
        padding=padding,
        stride=(sH, sW),
    ).to(
        input_tensor.dtype
    )  # (N, C*kH*kW, L)

    # Transpose to (N, L, C*kH*kW) so each row is one flattened patch.
    patches = patches.transpose(1, 2).contiguous()

    # Reshape to (N, L, C*kH*kW): one row per patch, grouped by batch entry.
    patches = patches.view(N, -1, C * kH * kW)

    # If channel_last, output should be in NHWC patch order (but im2row is always row-major)
    return patches
1398
+
1399
+
1400
@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    in_zero_point: int,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Per-tensor variant of im2row: accepts a single Python int zero point,
    wraps it in an int32 tensor, and delegates to im2row.
    """
    # im2row expects the zero point as an int32 tensor; lift the scalar here.
    zero_point_tensor = torch.tensor(in_zero_point, dtype=torch.int32)
    return im2row(
        input_tensor,
        kernel_size,
        dilation,
        padding,
        stride,
        zero_point_tensor,
        channel_last,
    )
0 commit comments