diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py
index 5f9f04dc63..84b2eddae7 100644
--- a/tests/test_trtllm_cutlass_fused_moe.py
+++ b/tests/test_trtllm_cutlass_fused_moe.py
@@ -215,7 +215,7 @@ def compute_with_experts(
     1,
 ]
 HIDDEN_SIZES = [
-    128,
+    256,
 ]
 NUM_EXPERTS = [2]
 TOP_K_VALUES = [2]
@@ -884,6 +884,7 @@ def dequantize_block(
     scales: torch.Tensor,
     dtype: torch.dtype,
     original_shape: tuple,
+    block_size_n: int = 128,
 ) -> torch.Tensor:
     """
     Dequantize a block-quantized tensor.
@@ -893,6 +894,7 @@
         scales: Block scaling factors
         dtype: Target dtype for dequantization
         original_shape: Original shape of the tensor before padding
+        block_size_n: Block size
 
     Returns:
         torch.Tensor: Dequantized tensor
@@ -904,8 +906,8 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
         if dim != -1:
             a = a.transpose(dim, -1)
         # Broadcast and reshape
-        a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128)
-        a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128)
+        a_broadcasted = a.unsqueeze(-1).expand(*a.shape, block_size_n)
+        a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * block_size_n)
         # Move back if needed
         if dim != -1:
             a_reshaped = a_reshaped.transpose(dim, -1)
@@ -913,9 +915,13 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
     if x_quant.dim() == 2:
         # For activation tensors [batch_size, hidden_size]
         batch_size, hidden_size = x_quant.shape
-        num_blocks = (hidden_size + 127) // 128
-        scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128)
-        scales = scales[:, :, : hidden_size % 128] if hidden_size % 128 != 0 else scales
+        num_blocks = ceil_div(hidden_size, block_size_n)
+        scales = (
+            scales.view(batch_size, num_blocks, 1)
+            .expand(-1, -1, block_size_n)
+            .reshape(batch_size, -1)
+        )
+        scales = scales[:, :hidden_size]
     else:
         # For weight tensors [..., in_dim, out_dim]
         *_dims, in_dim, out_dim = x_quant.shape
@@ -924,10 +930,10 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
         scales = transform_dim(scales, -2)  # Second-to-last dim
 
         # Handle padding
-        if in_dim % 128 != 0:
-            scales = scales[..., : in_dim % 128, :]
-        if out_dim % 128 != 0:
-            scales = scales[..., :, : out_dim % 128]
+        if in_dim % block_size_n != 0:
+            scales = scales[..., : in_dim % block_size_n, :]
+        if out_dim % block_size_n != 0:
+            scales = scales[..., :, : out_dim % block_size_n]
 
     x_dequant = x_quant.to(dtype) * scales.to(dtype)
     return x_dequant.view(original_shape)
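
For context on the reworked 2D (activation) path: the per-block scales of shape `[batch_size, num_blocks]` are broadcast to `block_size_n` elements per block, flattened back to `[batch_size, num_blocks * block_size_n]`, and then truncated to `hidden_size` so the padding in a partial last block is dropped. Below is a minimal standalone sketch of that expansion, checked against a naive per-element block lookup. It assumes a `ceil_div` helper like the one the patch calls, and the name `expand_activation_scales` is used purely for illustration, not something from the test file.

```python
import torch


def ceil_div(a: int, b: int) -> int:
    # Assumed helper, mirroring the ceil_div the patch calls in the test file.
    return (a + b - 1) // b


def expand_activation_scales(
    scales: torch.Tensor, hidden_size: int, block_size_n: int = 128
) -> torch.Tensor:
    """Expand per-block scales [batch, num_blocks] to per-element scales
    [batch, hidden_size], following the patched 2D path above."""
    batch_size = scales.shape[0]
    num_blocks = ceil_div(hidden_size, block_size_n)
    out = (
        scales.view(batch_size, num_blocks, 1)
        .expand(-1, -1, block_size_n)
        .reshape(batch_size, -1)
    )
    # Drop the padding that block quantization added to a partial last block.
    return out[:, :hidden_size]


if __name__ == "__main__":
    batch, hidden, block = 3, 200, 128  # hidden is not a multiple of block
    scales = torch.rand(batch, ceil_div(hidden, block))
    fast = expand_activation_scales(scales, hidden, block)
    # Naive reference: element j of row i uses the scale of its block j // block.
    ref = torch.stack(
        [scales[i, torch.arange(hidden) // block] for i in range(batch)]
    )
    assert fast.shape == (batch, hidden)
    assert torch.equal(fast, ref)
    print("per-element scales match the naive block lookup")
```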