|
29 | 29 | from ...models.modeling_utils import QuantConfig |
30 | 30 | from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE |
31 | 31 | from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE |
32 | | -from ..utils import Fp4QuantizedTensor |
| 32 | +from ..utils import Fp4QuantizedTensor, unswizzle_sf |
33 | 33 |
|
34 | 34 |
|
35 | 35 | class WeightMode(str, enum.Enum): |
@@ -824,6 +824,9 @@ def apply(self, module: Linear, input: torch.Tensor, |
824 | 824 | act_sf, |
825 | 825 | module.weight_scale, |
826 | 826 | module.alpha, module.dtype) |
| 827 | + # Trim padded columns so the output width matches out_features. |
| 828 | + if output.shape[-1] > module.out_features: |
| 829 | + output = output[..., :module.out_features] |
827 | 830 |
|
828 | 831 | if bias is not None: |
829 | 832 | output = output + bias |
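
The new slice only has an effect when the weight was column-padded for the NVFP4 GEMM (see the `post_load_weights` hunk below); the padded GEMM then returns extra trailing columns that have to be dropped. A minimal standalone sketch of that trim, using made-up sizes rather than anything taken from the real kernel:

```python
import torch

# Illustrative sizes only: a layer with 7 output features whose weight
# rows were padded up to the next multiple of 32 for the GEMM.
out_features, padded_out = 7, 32
x = torch.randn(4, 16)
w_padded = torch.zeros(padded_out, 16)
w_padded[:out_features] = torch.randn(out_features, 16)

# The padded GEMM produces 32 output columns; only the first 7 are real.
y = x @ w_padded.t()
assert y.shape[-1] > out_features
y = y[..., :out_features]  # same trim as in apply()
assert y.shape == (4, out_features)
```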
@@ -957,6 +960,48 @@ def load_weights_fused_gate_up_linear(self, module: Linear, |
957 | 960 | copy_weight(module.alpha, alpha) |
958 | 961 | module.scalar_alpha = alpha.item() |
959 | 962 |
|
| 963 | + def post_load_weights(self, module: Linear): |
| 964 | + """ |
| 965 | + Pad the weight and weight_scale tensors to meet the torch trtllm NVFP4 |
| 966 | + GEMM alignment requirements. |
| 967 | +
| 968 | + Rows are padded to a multiple of 32 and columns to a multiple of 16; |
| 969 | + these alignments are fixed by the kernel, not configurable arguments. |
| 970 | + """ |
| 971 | + super().post_load_weights(module) |
| 972 | + row_alignment, col_alignment = 32, 16 |
| 973 | + row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment |
| 974 | + col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment |
| 975 | + if row_pad_size != 0 or col_pad_size != 0: |
| 976 | + # Pad weight to meet NVFP4 GEMM kernel alignment requirements |
| 977 | + module.weight = Parameter(F.pad(module.weight, |
| 978 | + (0, col_pad_size, 0, row_pad_size), |
| 979 | + mode='constant', |
| 980 | + value=0), |
| 981 | + requires_grad=False) |
| 982 | + weight_col_size = module.weight.size(1) |
| 983 | + assert ( |
| 984 | + weight_col_size * 2 |
| 985 | + ) % module.scaling_vector_size == 0, f"unpacked weight column size after padding ({weight_col_size} * 2) must be divisible by scaling_vector_size {module.scaling_vector_size}" |
| 986 | + # Pad weight_scale to match padded weight dimensions |
| 987 | + # Padding should be performed on unswizzled weight_scale tensor |
| 988 | + scale_rows = fp4_utils.pad_up(module.out_features, 128) |
| 989 | + scale_cols = fp4_utils.pad_up( |
| 990 | + module.in_features // module.scaling_vector_size, 4) |
| 991 | + weight_scale_unswizzle = unswizzle_sf(module.weight_scale.data, |
| 992 | + scale_rows, scale_cols, |
| 993 | + module.scaling_vector_size) |
| 994 | + weight_scale_unswizzle_pad = F.pad( |
| 995 | + weight_scale_unswizzle, |
| 996 | + (0, (col_pad_size * 2) // module.scaling_vector_size, 0, |
| 997 | + row_pad_size), |
| 998 | + mode='constant', |
| 999 | + value=0) |
| 1000 | + module.weight_scale = Parameter( |
| 1001 | + torch.ops.trtllm.block_scale_interleave( |
| 1002 | + weight_scale_unswizzle_pad), |
| 1003 | + requires_grad=False) |
| 1004 | + |
960 | 1005 |
|
961 | 1006 | class W4A8NVFP4FP8LinearMethod(LinearMethodBase): |
962 | 1007 |
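
For reference, the alignment math in `post_load_weights` uses the usual `(alignment - n) % alignment` formula, which is zero when `n` is already aligned, and `F.pad` takes pad widths for the last dimension first, so `(0, col_pad_size, 0, row_pad_size)` pads columns on the right and rows at the bottom. A small illustrative sketch of just that arithmetic (the 32/16 alignments mirror the hunk above; the tensor shape is made up):

```python
import torch
import torch.nn.functional as F

def pad_size(n: int, alignment: int) -> int:
    # Rows/cols to add so that n becomes a multiple of `alignment`;
    # zero when n is already aligned.
    return (alignment - n) % alignment

row_alignment, col_alignment = 32, 16
weight = torch.randn(40, 30)  # made-up unaligned shape

row_pad = pad_size(weight.size(0), row_alignment)  # 24 -> 40 + 24 = 64
col_pad = pad_size(weight.size(1), col_alignment)  # 2  -> 30 + 2  = 32

# F.pad lists pads for the last dim first: (left, right, top, bottom).
padded = F.pad(weight, (0, col_pad, 0, row_pad), mode='constant', value=0)
assert padded.shape == (64, 32)
```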
|
|