@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
         torch.tensor(in_zero_point, dtype=torch.int32),
         channel_last,
     )
+
+
+@impl(m, "transposed_im2row")
+def transposed_im2row(
+    input_tensor: torch.Tensor,
+    kernel_size: tuple[int, int],
+    dilation: tuple[int, int],
+    padding: tuple[int, int],
+    stride: tuple[int, int],
+    output_padding: tuple[int, int],
+    in_zero_point: torch.Tensor,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    """
+    Converts an input tensor into im2row format for transposed convolutions.
+    This function extracts patches from the input in a pattern suitable for
+    transposed convolution.
+
+    Args:
+    - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
+    - kernel_size: Size of the convolution kernel.
+    - dilation: Dilation of the convolution kernel.
+    - padding: Padding to apply to the input.
+    - stride: Stride of the convolution.
+    - output_padding: Additional output padding for transposed convolution.
+    - in_zero_point: Zero point for input quantization (broadcastable to input).
+    - channel_last: If True, input is in NHWC format, else NCHW.
+
+    Returns:
+    - 3D tensor of shape (N, H_out * W_out, C * H_in * W_in)
+    """
+    # Handle the 1D convolution case by adding a height dimension
+    if len(input_tensor.shape) == 3:
+        height_dim = 1 if channel_last else 2
+        input_tensor = input_tensor.unsqueeze(height_dim)
+
+    if in_zero_point is not None:
+        if in_zero_point.dtype != torch.int32:
+            raise ValueError("Input zero point must be an int32 tensor")
+
+    # Move to NCHW for processing if needed
+    if channel_last:
+        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW
+
+    N, C, H_in, W_in = input_tensor.shape
+
+    # For each input pixel, create a channel where the upsampled
+    # (transposed conv) patch is placed.
+    # Output: (N, C*H_in*W_in, H_out, W_out)
+    inp_flat = input_tensor.reshape(N, C * H_in * W_in)
+
+    # Calculate the output spatial size
+    H_out = (
+        (H_in - 1) * stride[0]
+        - 2 * padding[0]
+        + dilation[0] * (kernel_size[0] - 1)
+        + output_padding[0]
+        + 1
+    )
+    W_out = (
+        (W_in - 1) * stride[1]
+        - 2 * padding[1]
+        + dilation[1] * (kernel_size[1] - 1)
+        + output_padding[1]
+        + 1
+    )
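+    # (Same output-size formula as torch.nn.functional.conv_transpose2d.)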
+
+    # Compute the upsampled (top-left) output position for each input pixel
+    h_idx = torch.arange(H_in, device=input_tensor.device)
+    w_idx = torch.arange(W_in, device=input_tensor.device)
+    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
+    out_h_idx = grid_h * stride[0] - padding[0]
+    out_w_idx = grid_w * stride[1] - padding[1]
+
+    # Compute all input pixel positions (flattened)
+    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
+    ij_idx = ch_idx % (H_in * W_in)
+    i_idx = ij_idx // W_in
+    j_idx = ij_idx % W_in
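+    # (ij_idx wraps every H_in * W_in entries, so every channel reuses the
+    # same spatial (i, j) grid.)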
+
+    # For each input pixel, compute the output positions of the kernel window
+    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
+    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
+    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
+    kh_grid = kh_grid.reshape(-1)
+    kw_grid = kw_grid.reshape(-1)
+    num_kernel = kernel_size[0] * kernel_size[1]
+
+    # Broadcast to all channels and kernel positions
+    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
+    n_kernel = ch_idx.shape[0] * num_kernel
+
+    i_idx_b = i_idx.repeat_interleave(num_kernel)
+    j_idx_b = j_idx.repeat_interleave(num_kernel)
+    kh_b = kh_grid.repeat(ch_idx.shape[0])
+    kw_b = kw_grid.repeat(ch_idx.shape[0])
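+    # (repeat_interleave on the pixel indices and repeat on the kernel offsets
+    # pair every (channel, pixel) entry with every (kh, kw) kernel offset.)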
+
+    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
+    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
+
+    # Mask for valid output positions
+    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
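+    # (Kernel taps that land outside the output, e.g. due to padding, are dropped.)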
+
+    # Prepare indices for advanced indexing
+    n_idx = (
+        torch.arange(N, device=input_tensor.device)
+        .view(-1, 1)
+        .expand(N, n_kernel)
+        .reshape(-1)
+    )
+    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
+    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
+    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
+    valid_full = valid.expand(N, n_kernel).reshape(-1)
+
+    # Gather input values for each channel
+    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)
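+    # (inp_vals has shape (N * n_kernel,), matching the index vectors above.)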
+
+    # Create the output tensor
+    patches = torch.zeros(
+        (N, C * H_in * W_in, H_out, W_out),
+        dtype=input_tensor.dtype,
+        device=input_tensor.device,
+    )
+
+    # If in_zero_point is provided, pre-fill patches with it
+    if in_zero_point is not None:
+        if in_zero_point.numel() == 1:
+            patches.fill_(in_zero_point.item())
+        else:
+            # Broadcast in_zero_point over (N, C*H_in*W_in, H_out, W_out)
+            assert in_zero_point.shape == (N,)
+            in_zero_point = in_zero_point.view(N, 1, 1, 1)
+            patches = patches + in_zero_point
+
+    # Scatter input values to the output positions (valid positions only)
+    patches[
+        n_idx[valid_full],
+        ch_idx_full[valid_full],
+        h_out_full[valid_full],
+        w_out_full[valid_full],
+    ] = inp_vals[valid_full]
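+    # (Within each (channel, pixel) plane the kernel taps hit distinct
+    # positions, so the assignment never writes the same element twice.)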
+
+    # Flatten to (N, num_patches, patch_size)
+    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
+    return patches
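
As a sanity check on the output-size arithmetic above: the formula matches the documented output size of `torch.nn.functional.conv_transpose2d`. A minimal sketch (example shapes are arbitrary, not taken from this change):

```python
import torch
import torch.nn.functional as F

# Example shapes, chosen only for illustration.
N, C, H_in, W_in = 1, 2, 5, 4
kernel_size, stride = (3, 3), (2, 2)
padding, dilation, output_padding = (1, 1), (1, 1), (1, 1)

# Same formula as in transposed_im2row above.
H_out = (
    (H_in - 1) * stride[0]
    - 2 * padding[0]
    + dilation[0] * (kernel_size[0] - 1)
    + output_padding[0]
    + 1
)

x = torch.randn(N, C, H_in, W_in)
w = torch.randn(C, 1, *kernel_size)  # conv_transpose2d weight: (in_ch, out_ch, kH, kW)
y = F.conv_transpose2d(
    x, w, stride=stride, padding=padding,
    output_padding=output_padding, dilation=dilation,
)
assert y.shape[2] == H_out  # both give 10 here
```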