@@ -106,13 +106,13 @@ def create_quantized_param(
106106
107107 if rows % block_size_m != 0 or cols % block_size_n != 0 :
108108 raise ValueError (
109- f"Matrix dimensions ({ rows } , { cols } ) must be divisible by block sizes ({ block_size_m } , { block_size_n } )"
109+ f"Matrix dimensions ({ rows } , { cols } ) must be divisible by block sizes ({ block_size_m } , { block_size_n } ) for { param_name } "
110110 )
111111 param_value_orig_shape = param_value .shape
112112
113113 param_value = param_value .reshape (
114- - 1 , rows // block_size_m , block_size_m , cols // block_size_n , block_size_n
115- ).permute (0 , 1 , 3 , 2 , 4 )
114+ rows // block_size_m , block_size_m , cols // block_size_n , block_size_n
115+ ).permute (0 , 2 , 1 , 3 )
116116
117117 # Calculate scaling factor for each block
118118 max_abs = torch .amax (torch .abs (param_value ), dim = (- 1 , - 2 ))
@@ -123,12 +123,12 @@ def create_quantized_param(
123123 # Quantize the weights
124124 quantized_param = torch .clamp (param_value * scale , min = fp8_min , max = fp8_max ).to (torch .float8_e4m3fn )
125125
126- quantized_param = quantized_param .permute (0 , 1 , 3 , 2 , 4 )
126+ quantized_param = quantized_param .permute (0 , 2 , 1 , 3 )
127127 # Reshape back to matrix shape
128128 quantized_param = quantized_param .reshape (param_value_orig_shape )
129129
130130 # Reshape scale to match the number of blocks
131- scale = scale .reshape (scale_orig_shape ).squeeze (). reciprocal ()
131+ scale = scale .reshape (scale_orig_shape ).reciprocal ()
132132
133133 # Load into the model
134134 module ._parameters [tensor_name ] = quantized_param .to (target_device )
0 commit comments