
Commit 947db7c

Merge pull request #436 from akx/quanitze
Fix typo "quanitze"
2 parents: 8c5c668 + 6b26402

3 files changed: +11 −11 lines changed

benchmarking/switchback/speed_benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

@@ -72,8 +72,8 @@ def get_time(k, fn, info_dict):
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
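
The second hunk's context line references the benchmark's get_time(k, fn, info_dict) helper. A rough sketch of how such a timing helper typically works is below; get_time_ref, the repeat count, and the dict layout are assumptions for illustration, not the file's actual implementation, and a CUDA device is assumed since the ops being timed run on GPU.

import time
import torch

# Hypothetical timing helper: warm up, synchronize the GPU around the timed
# region, and record the average seconds per call under the given key.
def get_time_ref(k, fn, info_dict, repeat=8):
    for _ in range(repeat):          # warmup so compilation/caching is excluded
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):          # measured runs
        fn()
    torch.cuda.synchronize()
    info_dict[k] = (time.time() - start) / repeat  # average seconds per call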

bitsandbytes/nn/triton_based_modules.py

Lines changed: 6 additions & 6 deletions
@@ -10,7 +10,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


class _switchback_global(torch.autograd.Function):
@@ -29,7 +29,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)

@@ -47,7 +47,7 @@ def backward(ctx, G_3D):
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
-grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
@@ -119,7 +119,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)

@@ -143,7 +143,7 @@ def backward(ctx, G_3D):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
-grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)

@@ -215,7 +215,7 @@ def forward(self, x):
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
-return int8_matmul_mixed_dequanitze(
+return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py renamed to bitsandbytes/triton/int8_matmul_mixed_dequantize.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
-def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:

import triton
@@ -136,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
tl.atomic_add(C, acc, mask=mask)


-def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
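
The renamed module keeps its Triton guard: when Triton is unavailable the function is stubbed to return None, and otherwise the kernel dequantizes the int32 accumulator with divfactor = 1/(127*127). A caller-side sketch is below; linear_mixed_or_fp is a hypothetical helper, and the quantize_rowwise/quantize_global import paths are assumed to mirror the imports shown earlier in this diff.

import torch.nn.functional as F
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise   # assumed path
from bitsandbytes.triton.quantize_global import quantize_global

# Guard on Triton availability: without Triton the stub returns None, so fall
# back to a plain floating-point matmul. Assumes x is 2D (flatten leading dims
# first, as the modules above do) and a CUDA device for the Triton path.
def linear_mixed_or_fp(x, w, bias=None):
    if not is_triton_available():
        return F.linear(x, w, bias)            # plain fp16/fp32 path
    x_int8, state_x = quantize_rowwise(x)      # per-row scales for activations
    w_int8, state_w = quantize_global(w)       # one global scale for weights
    return int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x, state_w, bias)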
