diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..90033649 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -61,6 +61,27 @@ def compression_param_names(self) -> Tuple[str]: "weight_global_scale", ) + def compression_param_info( + self, + weight_shape: torch.Size, + quantization_args: Optional[QuantizationArgs] = None, + ) -> Dict[str, Tuple[torch.Size, torch.dtype]]: + """ + Creates a dictionary of expected shapes and dtypes for each compression + parameter used by the compressor + + :param weight_shape: uncompressed weight shape + :param quantization_args: quantization parameters for the weight + :return: dictionary mapping compressed parameter names to shape and dtype + """ + output = { + "weight_packed": ( + torch.Size((weight_shape[0], weight_shape[1] // 2)), + torch.uint8, + ), + } + return output + def compress_weight( self, weight: Tensor, diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b82a4195..d3c9da40 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -257,13 +257,10 @@ def _process_quantization( QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): - n_dims = x.shape - if len(n_dims) > 2: - x = x.squeeze(0) output_dtype = dtype if dtype is not None else x.dtype output = torch.zeros_like(x).to(output_dtype) - columns = output.shape[1] + columns = output.shape[-1] # TODO: make validation step for inputs @@ -293,14 +290,12 @@ def _process_quantization( perm = torch.argsort(g_idx) x = safe_permute(x, perm, dim=1) - x = torch.reshape( - x, - ( - x.shape[0], - ceil(x.shape[1] / group_size), - group_size, - ), + # Maintain all dimensions apart from the last dim, which is divided by the group_size + reshaped_dims = ( + ceil(x.shape[-1] / group_size), + group_size, ) + x = x.unflatten(-1, reshaped_dims) if do_quantize: output = _quantize( @@ -323,19 +318,12 @@ def _process_quantization( global_scale=global_scale, ) - output = torch.reshape( - output, - (output.shape[0], output.shape[1] * output.shape[2]), - ) - + output = output.flatten(start_dim=-2) output = output.to(output_dtype) if not is_column_order: output = safe_permute(output, torch.argsort(perm), dim=1) - if len(n_dims) > 2: - output = output.unsqueeze(0) - else: # covers channel, token and tensor strategies if do_quantize: output = _quantize( diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 42a6e19e..5d28cac2 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -175,20 +175,16 @@ def compute_dynamic_scales_and_zp( QuantizationStrategy.TENSOR_GROUP, QuantizationStrategy.GROUP, ): - if len(value.shape) > 2: - value = value.squeeze(0) - dim = {0, 1} - reduce_dims = tuple(idx for idx in range(3) if idx not in dim) + reduce_dims = -1 keep_dims = False - value = torch.reshape( - value, - ( - value.shape[0], - math.ceil(value.shape[1] / args.group_size), - args.group_size, - ), + + reshaped_dims = ( + math.ceil(value.shape[-1] / args.group_size), + args.group_size, ) + value = value.unflatten(-1, reshaped_dims) + else: supported_strategies = ( QuantizationStrategy.TOKEN, diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index 2c6b1224..b9f9754c 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -83,7 +83,7 @@ def test_fused_global_scales(): "shape,group_size,exp_shape", [ # Only batch size =1 is supported for dynamic GROUP quantization - ((1, 4, 8), 4, torch.Size([4, 2])), + ((1, 4, 8), 4, torch.Size([1, 4, 2])), ], ) def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):