1010from bitsandbytes .triton .quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
1111from bitsandbytes .triton .int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
1212from bitsandbytes .triton .quantize_global import quantize_global , quantize_global_transpose
13- from bitsandbytes .triton .int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
13+ from bitsandbytes .triton .int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
1414
1515
1616class _switchback_global (torch .autograd .Function ):
@@ -29,7 +29,7 @@ def forward(ctx, X_3D, W, bias):
2929
3030 # matmult, fused dequant and add bias
3131 # call "mixed" because we are mixing rowwise quantized and global quantized
32- return int8_matmul_mixed_dequanitze (
32+ return int8_matmul_mixed_dequantize (
3333 X_int8 , W_int8 .t (), state_X , state_W , bias
3434 ).view (* X_3D .size ()[:- 1 ], - 1 )
3535
@@ -47,7 +47,7 @@ def backward(ctx, G_3D):
4747 # so we transpose once then call .t() in the matmul
4848 G_int8 , state_G = quantize_rowwise (G )
4949 W_int8 , state_W = quantize_global_transpose (W )
50- grad_X = int8_matmul_mixed_dequanitze (G_int8 , W_int8 .t (), state_G , state_W , None ).view (
50+ grad_X = int8_matmul_mixed_dequantize (G_int8 , W_int8 .t (), state_G , state_W , None ).view (
5151 * G_3D .size ()[:- 1 ], - 1
5252 )
5353 if ctx .needs_input_grad [1 ]:
@@ -119,7 +119,7 @@ def forward(ctx, X_3D, W, bias):
119119
120120 # matmult, fused dequant and add bias
121121 # call "mixed" because we are mixing rowwise quantized and global quantized
122- return int8_matmul_mixed_dequanitze (
122+ return int8_matmul_mixed_dequantize (
123123 X_int8 , W_int8 .t (), state_X , state_W , bias
124124 ).view (* X_3D_sz [:- 1 ], - 1 )
125125
@@ -143,7 +143,7 @@ def backward(ctx, G_3D):
143143 G_int8 , state_G = quantize_rowwise (G )
144144 del G
145145 W_int8 = W_int8 .t ().contiguous ()
146- grad_X = int8_matmul_mixed_dequanitze (G_int8 , W_int8 .t (), state_G , state_W , None ).view (
146+ grad_X = int8_matmul_mixed_dequantize (G_int8 , W_int8 .t (), state_G , state_W , None ).view (
147147 * G_3D_sz [:- 1 ], - 1
148148 )
149149
@@ -215,7 +215,7 @@ def forward(self, x):
215215 X_int8 , self .W_int8 .t (), state_X , self .state_W , self .bias
216216 ).view (* x .size ()[:- 1 ], - 1 )
217217 else :
218- return int8_matmul_mixed_dequanitze (
218+ return int8_matmul_mixed_dequantize (
219219 X_int8 , self .W_int8 .t (), state_X , self .state_W , self .bias
220220 ).view (* x .size ()[:- 1 ], - 1 )
221221
0 commit comments