@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
        torch.tensor(in_zero_point, dtype=torch.int32),
        channel_last,
    )


@impl(m, "transposed_im2row")
def transposed_im2row(
    input_tensor: torch.Tensor,
    kernel_size: tuple[int, int],
    dilation: tuple[int, int],
    padding: tuple[int, int],
    stride: tuple[int, int],
    output_padding: tuple[int, int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """
    Converts input tensor patches into im2row format for transposed convolutions.
    This function extracts patches from the input in a pattern suitable for
    transposed convolution.

    Args:
    - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
    - kernel_size: Size of the convolution kernel.
    - dilation: Dilation of the convolution kernel.
    - padding: Padding to apply to the input.
    - stride: Stride of the convolution.
    - output_padding: Additional output padding for transposed convolution.
    - in_zero_point: Zero point for input quantization (broadcastable to input).
    - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
    - 3D tensor of shape (N, H_out * W_out, C * H_in * W_in): one row per
      output position, holding that position's per-input-pixel contributions.
    """
    # Handle the 1D convolution case by adding a height dimension
    if len(input_tensor.shape) == 3:
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

    if in_zero_point is not None:
        if in_zero_point.dtype != torch.int32:
            raise ValueError("Input zero point must be an int32 tensor")

    # Move to NCHW for processing if needed
    if channel_last:
        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

    N, C, H_in, W_in = input_tensor.shape
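    # From here on the tensor is NCHW: N batches, C channels, H_in x W_in
    # spatial extent; H_out/W_out below follow the ConvTranspose2d convention.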

    # For each input pixel, create a channel where the upsampled (transposed
    # conv) patch is placed. Intermediate layout: (N, C*H_in*W_in, H_out, W_out)
    inp_flat = input_tensor.reshape(N, C * H_in * W_in)

    # Calculate the output spatial size (standard transposed-conv formula)
    H_out = (
        (H_in - 1) * stride[0]
        - 2 * padding[0]
        + dilation[0] * (kernel_size[0] - 1)
        + output_padding[0]
        + 1
    )
    W_out = (
        (W_in - 1) * stride[1]
        - 2 * padding[1]
        + dilation[1] * (kernel_size[1] - 1)
        + output_padding[1]
        + 1
    )
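    # Worked example with illustrative numbers: H_in=4, stride=2, padding=1,
    # dilation=1, kernel_size=3, output_padding=0 gives
    # H_out = (4-1)*2 - 2*1 + 1*(3-1) + 0 + 1 = 7,
    # matching the output size reported by torch.nn.ConvTranspose2d.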

    # Compute the upsampled (top-left) output position of each input pixel
    h_idx = torch.arange(H_in, device=input_tensor.device)
    w_idx = torch.arange(W_in, device=input_tensor.device)
    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
    out_h_idx = grid_h * stride[0] - padding[0]
    out_w_idx = grid_w * stride[1] - padding[1]

    # Compute all input pixel positions (flattened)
    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
    ij_idx = ch_idx % (H_in * W_in)
    i_idx = ij_idx // W_in
    j_idx = ij_idx % W_in

    # For each input pixel, compute the output positions of the kernel window
    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
    kh_grid = kh_grid.reshape(-1)
    kw_grid = kw_grid.reshape(-1)
    num_kernel = kernel_size[0] * kernel_size[1]
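    # At this point ch_idx enumerates the C*H_in*W_in per-pixel channels and
    # (kh_grid, kw_grid) the num_kernel taps of the kernel window.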

    # Broadcast to all channels and kernel positions
    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
    n_kernel = ch_idx.shape[0] * num_kernel

    i_idx_b = i_idx.repeat_interleave(num_kernel)
    j_idx_b = j_idx.repeat_interleave(num_kernel)
    kh_b = kh_grid.repeat(ch_idx.shape[0])
    kw_b = kw_grid.repeat(ch_idx.shape[0])

    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
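    # h_out/w_out give, for every (input pixel, kernel tap) pair, the output
    # coordinate that tap writes to; padding can push these off the canvas,
    # hence the validity mask below.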

    # Mask for valid output positions
    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)

    # Prepare indices for advanced indexing
    n_idx = (
        torch.arange(N, device=input_tensor.device)
        .view(-1, 1)
        .expand(N, n_kernel)
        .reshape(-1)
    )
    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
    valid_full = valid.expand(N, n_kernel).reshape(-1)

    # Gather input values for each channel
    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)

    # Create the output tensor on the input's device
    patches = torch.zeros(
        (N, C * H_in * W_in, H_out, W_out),
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )

    # If in_zero_point is provided, pre-fill patches with it so positions never
    # written below hold the quantized representation of zero
    if in_zero_point is not None:
        if in_zero_point.numel() == 1:
            patches.fill_(in_zero_point.item())
        else:
            # Broadcast per-batch in_zero_point over (N, C*H_in*W_in, H_out, W_out)
            assert in_zero_point.shape == (N,)
            in_zero_point = in_zero_point.view(N, 1, 1, 1)
            patches = patches + in_zero_point

    # Scatter input values to their output positions (valid positions only)
    patches[
        n_idx[valid_full],
        ch_idx_full[valid_full],
        h_out_full[valid_full],
        w_out_full[valid_full],
    ] = inp_vals[valid_full]

    # Flatten to (N, H_out * W_out, C * H_in * W_in)
    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
    return patches
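
# Usage sketch (illustrative, not part of this diff): exercising the reference
# implementation directly on a small float tensor; `in_zero_point=None` takes
# the non-quantized path. Expected size follows the transposed-conv formula.
if __name__ == "__main__":
    x = torch.arange(1 * 2 * 3 * 3, dtype=torch.float32).reshape(1, 2, 3, 3)
    rows = transposed_im2row(
        x,
        kernel_size=(3, 3),
        dilation=(1, 1),
        padding=(1, 1),
        stride=(2, 2),
        output_padding=(0, 0),
        in_zero_point=None,
        channel_last=False,
    )
    # H_out = W_out = (3-1)*2 - 2*1 + 1*(3-1) + 0 + 1 = 5,
    # so rows has shape (1, 5*5, 2*3*3) = (1, 25, 18)
    print(rows.shape)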