Commit dfc4668

int8: specify CUDA stream for int8 ops
1 parent d231db7 commit dfc4668

6 files changed: +62 -57 lines changed
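Why the change matters: PyTorch enqueues its own kernels on the current CUDA stream, which is not necessarily the legacy default stream. Before this commit, the int8 kernels were launched with the two-argument <<<grid, block>>> configuration and cublasLtMatmul received stream 0, so all of that work went to the default stream regardless of what PyTorch was doing. A minimal sketch of the hazard (assumes any CUDA-enabled PyTorch; not specific to bitsandbytes):

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    A = torch.randn(32, 64, device="cuda", dtype=torch.float16)
    # Everything PyTorch enqueues here goes onto `s`. PyTorch streams are
    # typically created non-blocking, so a kernel launched on the default
    # stream is not ordered after the randn above and could read `A` too
    # early.
    assert torch.cuda.current_stream(A.device) == s
torch.cuda.synchronize()

The commit threads the stream returned by torch.cuda.current_stream through every int8 entry point, file by file below.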

bitsandbytes/functional.py

Lines changed: 22 additions & 16 deletions
@@ -442,8 +442,7 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]):
 
 
 def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
-    stream = torch.cuda.current_stream(tensor.device)
-    return stream
+    return torch.cuda.current_stream(tensor.device)
 
 
 def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
@@ -461,8 +460,8 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
     """
     if A is None:
         return None
-    else:
-        return ct.c_void_p(A.data.data_ptr())
+
+    return ct.c_void_p(A.data_ptr())
 
 
 def pre_call(device):
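Note on the get_ptr cleanup: Tensor.data is a legacy alias that shares storage with the tensor itself, so dropping it does not change the address that reaches C. A quick check (assumes any CUDA tensor):

import ctypes as ct
import torch

A = torch.arange(4, device="cuda", dtype=torch.int8)
# `.data` shares storage with `A`, so both expressions yield the same raw
# device address; the new code simply skips the redundant hop.
assert A.data_ptr() == A.data.data_ptr()
ptr = ct.c_void_p(A.data_ptr())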
@@ -2323,11 +2322,12 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
     ptrC = get_ptr(out)
     ptrRowScale = get_ptr(None)
     m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc))
+    stream = get_tensor_stream(A)
 
     if dtype == torch.int32:
-        has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+        has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
     else:
-        has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+        has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
 
     if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
         raise NotImplementedError("igemmlt not implemented!")
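Note: get_tensor_stream returns a torch.cuda.Stream object, which is what the ctypes calls above receive. For reference, the raw cudaStream_t handle is exposed as Stream.cuda_stream; a hypothetical helper that passes the handle explicitly as an opaque pointer (get_tensor_stream_handle is illustrative, not part of this commit) would look like:

import ctypes as ct
import torch

def get_tensor_stream_handle(tensor: torch.Tensor) -> ct.c_void_p:
    # Stream.cuda_stream is the integer value of the underlying
    # cudaStream_t, which ctypes can hand to C as a void*.
    return ct.c_void_p(torch.cuda.current_stream(tensor.device).cuda_stream)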
@@ -2373,13 +2373,7 @@ def mm_dequant(
 
     with torch.cuda.device_of(A):
         lib.cdequant_mm_int32_fp16(
-            ptrA,
-            ptrRowStats,
-            ptrColStats,
-            ptrOut,
-            ptrBias,
-            numRows,
-            numCols,
+            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, get_tensor_stream(A)
         )
 
     return out
@@ -2428,7 +2422,14 @@ def get_row_absmax(A, threshold=0.0):
     is_on_gpu([A])
 
     with torch.cuda.device_of(A):
-        lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+        lib.cget_row_stats(
+            get_ptr(A),
+            get_ptr(row_stats),
+            ct.c_float(threshold),
+            ct.c_int32(rows),
+            ct.c_int32(cols),
+            get_tensor_stream(A),
+        )
 
     return row_stats
 
@@ -2547,12 +2548,16 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     rows = prod(A.shape[:-1])
     cols = A.shape[-1]
 
-    row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32)
+    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
     out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
 
     if threshold > 0.0:
         # TODO we could improve perf of this
-        coo_tensor = extract_outliers_new(A, threshold)
+
+        # A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo()
+        # coo_tensor = extract_outliers_new(A, threshold)
+        coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
+
     else:
         coo_tensor = None
 
@@ -2564,6 +2569,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
             ct.c_float(threshold),
             ct.c_int32(rows),
             ct.c_int32(cols),
+            get_tensor_stream(A),
         )
 
     return out_row, row_stats, coo_tensor
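The threshold branch above now builds the outlier COO tensor with plain PyTorch ops instead of extract_outliers_new: zero out every entry below the threshold, then convert the survivors to sparse COO. A worked example of that one-liner (values chosen arbitrarily):

import torch

A = torch.tensor([[0.5, 8.0, -0.1],
                  [-7.5, 0.2, 0.0]], device="cuda", dtype=torch.float16)
threshold = 6.0
# Entries with |a| < threshold are zeroed; to_sparse_coo() then keeps only
# the nonzero outliers: 8.0 at (0, 1) and -7.5 at (1, 0).
coo = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
print(coo.indices())  # tensor([[0, 1], [1, 0]], device='cuda:0')
print(coo.values())   # tensor([ 8.0000, -7.5000], ...)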

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 3 deletions
@@ -481,10 +481,8 @@ def forward(self, x: torch.Tensor):
             x = x.to(self.compute_dtype)
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
-
-        out = out.to(inp_dtype)
 
+        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
         return out
 
 
csrc/kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -3558,6 +3558,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
   const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
+  const int offset_B = ldb*row_B;
   const int num_values_8bit = num_values_4bit/2;
   float local_C = 0.0f;
 
@@ -3578,7 +3579,6 @@
   for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
   {
     const int inner_idx_halved = inner_idx/2;
-    const int offset_B = ldb*row_B;
     const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize));
     //int absidx = ((2*offset_B)+inner_idx)/blocksize;
     local_absmax = __ldg(&(absmax[absidx]));
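(This hunk is independent of the stream work: it hoists the loop-invariant offset_B = ldb*row_B out of the per-warp inner loop of the 4-bit inference GEMM kernel, so it is computed once per thread instead of on every iteration.)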

csrc/ops.cu

Lines changed: 16 additions & 15 deletions
@@ -423,7 +423,8 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
     const int8_t * B,
     void * C,
     float * row_scale,
-    int lda, int ldb, int ldc
+    int lda, int ldb, int ldc,
+    cudaStream_t stream
 ) {
 
     // Calculate C = A^T @ B, in col-major layout.
@@ -461,7 +462,7 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
       B, bDesc, &beta,
       (int32_t*)C, cDesc,
       (int32_t*)C, cDesc,
-      NULL, NULL, 0, 0
+      NULL, NULL, 0, stream
     ));
   } else {
     if (!SCALE_ROWS) {
@@ -472,7 +473,7 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
         B, bDesc, &beta,
         (int8_t*)C, cDesc,
         (int8_t*)C, cDesc,
-        NULL, NULL, 0, 0
+        NULL, NULL, 0, stream
       ));
     } else {
       cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
@@ -489,7 +490,7 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
         B, bDesc, &beta,
         (int8_t*)C, cDesc,
         (int8_t*)C, cDesc,
-        NULL, NULL, 0, 0
+        NULL, NULL, 0, stream
       ));
     }
   }
@@ -510,23 +511,23 @@ int fill_up_to_nearest_multiple(int value, int multiple)
   return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
 }
 
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols)
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream)
 {
   const int threads = 512;
   const int num_per_thread = 4;
   const int num_per_block = threads * num_per_thread;
   const int n = numRows*numCols;
   const int num_blocks = (n + num_per_block - 1) / num_per_block;
 
-  kdequant_mm_int32_fp16<num_per_thread, threads><<<num_blocks, threads>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
+  kdequant_mm_int32_fp16<num_per_thread, threads><<<num_blocks, threads, 0, stream>>>(A, rowStats, colStats, out, bias, numRows, numCols, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) {
+void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
   if (threshold == 0.0) {
-    kInt8VectorQuant<half, 1024, 0><<<rows, 1024>>>(A, out, rowStats, threshold, rows, cols);
+    kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
   } else {
-    kInt8VectorQuant<half, 1024, 1><<<rows, 1024>>>(A, out, rowStats, threshold, rows, cols);
+    kInt8VectorQuant<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
  }
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
@@ -553,11 +554,11 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
 
 }
 
-void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols) {
+void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
   if (threshold == 0.0)
-    kgetRowStats<half, 1024, 0><<<rows, 1024>>>(A, rowStats, threshold, rows, cols);
+    kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
   else
-    kgetRowStats<half, 1024, 1><<<rows, 1024>>>(A, rowStats, threshold, rows, cols);
+    kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
@@ -795,9 +796,9 @@ template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx
 template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
 template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
 
-template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
-template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
-template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
+template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
+template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
 
 template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
 template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
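In the launch configuration <<<blocks, threads, 0, stream>>> above, the third argument is dynamic shared memory (unused here) and the fourth is the stream; cublasLtMatmul likewise takes the stream as its final argument. With every launch now on the caller's stream, the int8 path composes with user-selected streams. A usage sketch (assumes a bitsandbytes build containing this commit):

import torch
import bitsandbytes.functional as F

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    A = torch.randn(32, 64, device="cuda", dtype=torch.float16)
    # The quantization kernel is launched on `s` via get_tensor_stream(A),
    # so it is correctly ordered after the randn above.
    out_row, row_stats, coo = F.int8_vectorwise_quant(A)
torch.cuda.synchronize()
print(out_row.dtype, row_stats.shape)  # torch.int8 torch.Size([32])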

csrc/ops.cuh

Lines changed: 4 additions & 4 deletions
@@ -171,16 +171,16 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i
 void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                     long long int strideA, long long int strideB, long long int strideC, int batchCount);
 
-template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
 
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
 void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
 void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
-void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols);
+void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
 void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
                        int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
-void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
+void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
 
 template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
 
csrc/pythonInterface.cpp

Lines changed: 18 additions & 18 deletions
@@ -175,14 +175,14 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo
 void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
 void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
 
-int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-  return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+  return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
-int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-  return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+  return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
-int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-  return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+  return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
 }
 
 void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
@@ -308,14 +308,14 @@ extern "C"
   Context *get_context(){ return new Context(); }
   ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
 
-  int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-    return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+  int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+    return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
   }
-  int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-    return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+  int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+    return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
   }
-  int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
-    return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+  int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
+    return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
   }
 
   #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
@@ -333,15 +333,15 @@ extern "C"
   MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
   MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
 
-  void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols)
-  { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); }
+  void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
+  { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
   void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
   { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
-  void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols) {
-    getRowStats(A, rowStats, threshold, rows, cols);
+  void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
+    getRowStats(A, rowStats, threshold, rows, cols, stream);
   }
-  void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) {
-    int8VectorQuant(A, out, rowStats, threshold, rows, cols);
+  void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
+    int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
   }
   void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
   { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }