Commit 510a880

Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup
1 parent ca372f2 commit 510a880

File tree: 11 files changed (+48 −99 lines)

.github/scripts/build-cuda.sh
Lines changed: 15 additions & 15 deletions

@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
 [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
 [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
-for NO_CUBLASLT in ON OFF; do
-    if [ "${build_os:0:6}" == ubuntu ]; then
-        image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
-        echo "Using image $image"
-        docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
-            "apt-get update \
-            && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
-            && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
-            && cmake --build ."
-    else
-        pip install cmake==3.28.3
-        cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
-        cmake --build . --config Release
-    fi
-done
+
+if [ "${build_os:0:6}" == ubuntu ]; then
+    image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
+    echo "Using image $image"
+    docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
+        "apt-get update \
+        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
+        && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
+        && cmake --build ."
+else
+    pip install cmake==3.28.3
+    cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
+    cmake --build . --config Release
+fi
+

 output_dir="output/${build_os}/${build_arch}"
 mkdir -p "${output_dir}"

CMakeLists.txt
Lines changed: 1 addition & 13 deletions

@@ -4,7 +4,6 @@
 # For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
 # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
-# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #     is whatever CMake finds on your path.
 # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.

@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     if(APPLE)
         message(FATAL_ERROR "CUDA is not supported on macOS" )
     endif()
-    option(NO_CUBLASLT "Disable CUBLAS" OFF)
     set(BUILD_CUDA ON)
     set(BUILD_MPS OFF)
-    message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )

@@ -166,9 +163,6 @@ if(BUILD_CUDA)
     list(APPEND SRC_FILES ${CUDA_FILES})

     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
-    if(NO_CUBLASLT)
-        string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
-    endif()
     add_compile_definitions(BUILD_CUDA)
 elseif(BUILD_MPS)
     if(NOT APPLE)

@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
-    if(NO_CUBLASLT)
-        target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
-    else()
-        target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
-    endif()
-
+    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
     set_target_properties(bitsandbytes
         PROPERTIES
             CUDA_SEPARABLE_COMPILATION ON

bitsandbytes/autograd/_functions.py
Lines changed: 16 additions & 14 deletions

@@ -283,9 +283,9 @@ def forward(
         B: torch.Tensor,
         out=None,
         bias: Optional[torch.Tensor] = None,
-        state=MatmulLtState,
+        state: MatmulLtState = None,
     ):
-        # state = state or MatmulLtState()
+        state = state or MatmulLtState()

         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False

@@ -318,7 +318,7 @@ def forward(
         if is_transposed:
             B = B.contiguous()

-        if (state.is_training and not has_grad) or state.CB is None:
+        if (state.is_training and not has_grad) or state.SCB is None:
             state.reset_grads()

         # 2. Quantize B

@@ -347,7 +347,7 @@ def forward(
             outliers = state.CB[:, state.idx].clone()
             state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
         else:
-            subA = state.subB = None
+            subA = None

         # 3. Int8 Matmul
         out32, Sout32 = F.igemmlt(CA, state.CB)

@@ -377,7 +377,11 @@ def forward(
             ctx.save_for_backward(None, None)

         output_shape = (*input_shape[:-1], state.CB.shape[0])
-        return output.reshape(output_shape).clone()
+
+        if len(input_shape) == 3:
+            return output.view(output_shape).clone()
+        else:
+            return output

     @staticmethod
     def backward(ctx, grad_output):

@@ -400,18 +404,16 @@ def backward(ctx, grad_output):
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

         Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
-        if req_gradB:
-            # grad_output.T @ A
-            # grad_weight = grad_output.t().mm(A)
-            grad_B = torch.matmul(grad_output.t(), A)
-            if state.threshold > 0.0 and subA is not None:
-                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
         # if req_gradB:
-        #
-        #     gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
-        #     grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+
+        #     grad_B = torch.matmul(grad_output.t(), A)
         #     if state.threshold > 0.0 and subA is not None:
         #         grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+        if req_gradB:
+            gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
+            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            if state.threshold > 0.0 and subA is not None:
+                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
             # grad_output @ B.T
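
Note on the grad_B change above: the weight gradient grad_B = grad_output.T @ A is now computed through the int8 pipeline (quantize, int8 matmul with int32 accumulation, dequantize) rather than an fp16 matmul. The following is a minimal pure-PyTorch sketch of that arithmetic; it only loosely mirrors how the backward feeds double_quant outputs into igemmlt and mm_dequant, and the shapes and the rowwise_int8 helper are illustrative, not code from this repository:

import torch

torch.manual_seed(0)
grad_output = torch.randn(32, 64)   # (batch, out_features), illustrative sizes
A = torch.randn(32, 128)            # (batch, in_features)

def rowwise_int8(x):
    # Symmetric row-wise quantization: int8 values plus per-row absmax scales.
    absmax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    return torch.round(x / absmax * 127).to(torch.int8), absmax

# Quantize both operands row-wise along the shared reduction (batch) dimension.
q_g, s_g = rowwise_int8(grad_output.t())  # (out_features, batch)
q_a, s_a = rowwise_int8(A.t())            # (in_features, batch)

# Int8 matmul accumulating in int32, then dequantize with the outer product
# of the two per-row scale vectors.
acc32 = q_g.to(torch.int32) @ q_a.to(torch.int32).t()
grad_B_q = acc32.float() * (s_g / 127.0) * (s_a.t() / 127.0)

grad_B_ref = grad_output.t() @ A
print((grad_B_q - grad_B_ref).abs().max())  # small quantization error vs. the fp32 reference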

bitsandbytes/cextension.py
Lines changed: 1 addition & 5 deletions

@@ -37,11 +37,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:

     The library is not guaranteed to exist at the returned path.
     """
-    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
-    if not cuda_specs.has_cublaslt:
-        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        library_name += "_nocublaslt"
-    library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
+    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:
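
With the _nocublaslt variant gone, path construction collapses to a single format string: one binary per CUDA version. A self-contained sketch of the resulting naming scheme; the suffix mapping below is an assumption for illustration (the real module derives DYNAMIC_LIBRARY_SUFFIX per platform elsewhere):

import platform

DYNAMIC_LIBRARY_SUFFIX = {"Darwin": ".dylib", "Windows": ".dll"}.get(platform.system(), ".so")

def cuda_library_name(cuda_version_string: str) -> str:
    # One library per CUDA version; no compute-capability variants anymore.
    return f"libbitsandbytes_cuda{cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

print(cuda_library_name("121"))  # e.g. libbitsandbytes_cuda121.so on Linux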

bitsandbytes/cuda_specs.py
Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ class CUDASpecs:
     cuda_version_tuple: Tuple[int, int]

     @property
-    def has_cublaslt(self) -> bool:
+    def has_imma(self) -> bool:
         return self.highest_compute_capability >= (7, 5)
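
The rename keeps the same check but describes what is actually tested: IMMA (int8 tensor core) availability on Turing (CC 7.5) and newer, rather than cuBLASLt linkage, since the library now always links cuBLASLt. A standalone equivalent using only torch, for illustration (the function name mirrors the new property but is otherwise an assumption):

import torch

def has_imma() -> bool:
    # True if any visible GPU has compute capability >= 7.5, i.e. supports
    # the int8 tensor core (IMMA) instructions used for fast 8-bit matmul.
    if not torch.cuda.is_available():
        return False
    return any(
        torch.cuda.get_device_capability(i) >= (7, 5)
        for i in range(torch.cuda.device_count())
    )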

bitsandbytes/diagnostics/cuda.py
Lines changed: 2 additions & 2 deletions

@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:

     print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

-    # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    # 7.5 is the minimum CC for int8 tensor cores
+    if not cuda_specs.has_imma:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!

csrc/ops.cu
Lines changed: 0 additions & 15 deletions

@@ -314,8 +314,6 @@ int roundoff(int v, int d) {
 }


-#ifdef NO_CUBLASLT
-#else
 template<int ORDER> cublasLtOrder_t get_order()
 {
     switch(ORDER)

@@ -347,7 +345,6 @@ template cublasLtOrder_t get_order<COL>();
 template cublasLtOrder_t get_order<COL32>();
 template cublasLtOrder_t get_order<COL_TURING>();
 template cublasLtOrder_t get_order<COL_AMPERE>();
-#endif


 template<int ORDER> int get_leading_dim(int dim1, int dim2)

@@ -379,8 +376,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)

 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
 {
-#ifdef NO_CUBLASLT
-#else
     cublasLtOrder_t orderA = get_order<SRC>();
     cublasLtOrder_t orderOut = get_order<TARGET>();
     int ldA = get_leading_dim<SRC>(dim1, dim2);

@@ -419,7 +414,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
     if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
     if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
     if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
-#endif
 }

 template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(

@@ -513,9 +507,6 @@

 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
-#ifdef NO_CUBLASLT
-    return ERR_NOT_IMPLEMENTED;
-#else
     int has_error = 0;
     cublasLtMatmulDesc_t matmulDesc = NULL;
     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;

@@ -570,7 +561,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
         printf("error detected");

     return has_error;
-#endif // NO_CUBLASLT
 }

 int fill_up_to_nearest_multiple(int value, int multiple)

@@ -681,10 +671,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o

 void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
 {
-
-#ifdef NO_CUBLASLT
-#else
-
     cusparseSpMatDescr_t descA;
     cusparseDnMatDescr_t descB, descC;

@@ -731,7 +717,6 @@
     CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
     CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
     CUDA_CHECK_RETURN( cudaFree(dBuffer) );
-#endif
 }

 template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)

tests/conftest.py
Lines changed: 0 additions & 4 deletions

@@ -7,10 +7,6 @@
 def pytest_runtest_call(item):
     try:
         item.runtest()
-    except NotImplementedError as nie:
-        if "NO_CUBLASLT" in str(nie):
-            pytest.skip("CUBLASLT not available")
-        raise
     except AssertionError as ae:
         if str(ae) == "Torch not compiled with CUDA enabled":
             pytest.skip("Torch not compiled with CUDA enabled")

tests/test_cuda_setup_evaluator.py
Lines changed: 0 additions & 20 deletions

@@ -13,15 +13,6 @@ def cuda120_spec() -> CUDASpecs:
     )


-@pytest.fixture
-def cuda111_noblas_spec() -> CUDASpecs:
-    return CUDASpecs(
-        cuda_version_string="111",
-        highest_compute_capability=(7, 2),
-        cuda_version_tuple=(11, 1),
-    )
-
-
 def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"

@@ -31,14 +22,3 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
     monkeypatch.setenv("BNB_CUDA_VERSION", "110")
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
     assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?
-
-
-def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
-    monkeypatch.setenv("BNB_CUDA_VERSION", "125")
-    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
-    assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?
-
-
-def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
-    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"

tests/test_linear8bitlt.py
Lines changed: 5 additions & 3 deletions

@@ -69,11 +69,13 @@ def test_linear_no_igemmlt():

     fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
+
+    assert linear_custom.state.CB is not None
+    assert not linear_custom.state.has_fp16_weights
     assert torch.allclose(fx_ref, fx_ours, atol=0.02)
     assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
-    assert not linear_custom.state.has_fp16_weights
-    assert linear_custom.state.CB is not None
-    assert linear_custom.state.CxB is None
+
+    # assert linear_custom.state.CxB is None


 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
