Skip to content

Commit f93062b

Browse files
Merge branch 'main' into cleanup
2 parents 63db654 + 5b01589 commit f93062b

File tree

9 files changed

+304
-68
lines changed

9 files changed

+304
-68
lines changed

.github/workflows/python-package.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ on:
1414
- "requirements*.txt"
1515
- "setup.py"
1616
- "pyproject.toml"
17-
- "pytest.ini"
1817
release:
1918
types: [published]
2019
workflow_dispatch: {} # Allow manual trigger

bitsandbytes/triton/int8_matmul_mixed_dequantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
99
else:
1010
import triton
1111
import triton.language as tl
12-
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
12+
13+
from .matmul_perf_model import early_config_prune, estimate_matmul_time
1314

1415
# This is a matmul kernel based on triton.ops.matmul
1516
# It is modified to support rowwise quantized input and global quantized weight

bitsandbytes/triton/int8_matmul_rowwise_dequantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
99
else:
1010
import triton
1111
import triton.language as tl
12-
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
12+
13+
from .matmul_perf_model import early_config_prune, estimate_matmul_time
1314

1415
# This is a matmul kernel based on triton.ops.matmul
1516
# It is modified to support rowwise quantized input and columnwise quantized weight
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Adapted from https://github.com/triton-lang/kernels/blob/eeeebdd8be7d13629de22d600621e6234057eed3/kernels/matmul_perf_model.py
2+
# https://github.com/triton-lang/kernels is licensed under the MIT License.
3+
4+
import functools
5+
import heapq
6+
7+
import torch
8+
9+
from triton import cdiv
10+
from triton.runtime import driver
11+
from triton.testing import (
12+
get_dram_gbps,
13+
get_max_simd_tflops,
14+
get_max_tensorcore_tflops,
15+
nvsmi,
16+
)
17+
18+
19+
@functools.lru_cache
def get_clock_rate_in_khz():
    """Return the maximum SM clock rate, scaled by 1e3, caching the result.

    The ``clocks.max.sm`` value is first requested through triton's ``nvsmi``
    helper. If that raises ``FileNotFoundError``, the value is read through
    ``pynvml`` for device index 0 instead.

    NOTE(review): the raw value is presumably in MHz, so the ``* 1e3`` yields
    kHz as the function name suggests -- confirm against the helpers' docs.
    """
    try:
        max_sm_clock = nvsmi(["clocks.max.sm"])[0]
    except FileNotFoundError:
        # Imported lazily so pynvml is only required on the fallback path.
        import pynvml

        pynvml.nvmlInit()
        gpu0 = pynvml.nvmlDeviceGetHandleByIndex(0)
        max_sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(gpu0, pynvml.NVML_CLOCK_SM)
    return max_sm_clock * 1e3
30+
31+
def get_tensorcore_tflops(device, num_ctas, num_warps, dtype):
    """Return the estimated tensor-core compute throughput in TOPS.

    The peak reported by ``get_max_tensorcore_tflops`` is scaled by the
    fraction of sub-cores that the launch (``num_ctas`` CTAs of ``num_warps``
    warps each) can keep busy.
    """
    # No more than 4 warps per CTA are counted towards occupancy.
    busy_warps = num_ctas * min(num_warps, 4)
    sm_count = driver.active.utils.get_device_properties(device)["multiprocessor_count"]
    subcore_count = sm_count * 4  # on recent GPUs
    utilization = min(subcore_count, busy_warps) / subcore_count
    return utilization * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device)
41+
42+
43+
def get_simd_tflops(device, num_ctas, num_warps, dtype):
    """Return the estimated SIMD (non tensor-core) compute throughput in TOPS.

    The peak reported by ``get_max_simd_tflops`` is scaled by the fraction of
    sub-cores that the launch (``num_ctas`` CTAs of ``num_warps`` warps each)
    can keep busy.
    """
    # No more than 4 warps per CTA are counted towards occupancy.
    busy_warps = num_ctas * min(num_warps, 4)
    sm_count = driver.active.utils.get_device_properties(device)["multiprocessor_count"]
    subcore_count = sm_count * 4  # on recent GPUs
    utilization = min(subcore_count, busy_warps) / subcore_count
    return utilization * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device)
51+
52+
53+
def get_tflops(device, num_ctas, num_warps, dtype):
    """Return the estimated compute throughput for ``dtype`` on ``device``.

    float32 on devices with compute-capability major version below 8 uses the
    SIMD estimate; every other combination uses the tensor-core estimate.
    """
    major = torch.cuda.get_device_capability(device)[0]
    simd_only = major < 8 and dtype == torch.float32
    estimator = get_simd_tflops if simd_only else get_tensorcore_tflops
    return estimator(device, num_ctas, num_warps, dtype)
58+
59+
60+
def estimate_matmul_time(
    # backend, device,
    num_warps,
    num_stages,  #
    A,
    B,
    C,  #
    M,
    N,
    K,  #
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    SPLIT_K,  #
    debug=False,
    **kwargs,  #
):
    """Estimate the running time of a matmul launch configuration in ms.

    The model is ``max(compute time, load time) + store time``. The element
    type and element size are taken from ``A``; the current CUDA device is
    used for all hardware queries. When ``debug`` is true the individual
    terms are printed. Extra keyword arguments are accepted and ignored.
    """
    mib = 1024 * 1024
    gib = 1024 * 1024 * 1024

    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    # Launch grid: one CTA per (M-tile, N-tile, K-split).
    num_cta_m = cdiv(M, BLOCK_M)
    num_cta_n = cdiv(N, BLOCK_N)
    num_ctas = num_cta_m * num_cta_n * SPLIT_K

    # If the input is smaller than the block size, a full block is still processed.
    eff_m = max(M, BLOCK_M)
    eff_n = max(N, BLOCK_N)

    # --- compute time ---
    total_ops = 2 * eff_m * eff_n * K / gib  # GOPS
    compute_ms = total_ops / get_tflops(device, num_ctas, num_warps, dtype)

    # --- load time ---
    num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"]
    active_cta_ratio = min(1, num_ctas / num_sm)
    # 32 active CTAs are enough to saturate 95% of the bandwidth ...
    bw_share_first = min(1, num_ctas / 32)
    # ... and CTAs 32-108 contribute the remaining 5%.
    bw_share_rest = max(min(1, (num_ctas - 32) / (108 - 32)), 0)
    dram_bw = get_dram_gbps(device) * (bw_share_first * 0.95 + bw_share_rest * 0.05)  # in GB/s
    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)

    # Assume 80% of the repeated (following) loads are served from the L2 cache.
    a_bytes = eff_m * K * dtsize
    b_bytes = eff_n * K * dtsize
    load_a_dram = a_bytes * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = a_bytes * 0.8 * (num_cta_n - 1)
    load_b_dram = b_bytes * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = b_bytes * 0.8 * (num_cta_m - 1)
    total_dram = (load_a_dram + load_b_dram) / mib  # MB
    total_l2 = (load_a_l2 + load_b_l2) / mib
    load_ms = total_dram / dram_bw + total_l2 / l2_bw

    # --- store time ---
    store_bw = dram_bw * 0.6  # :o
    store_c_dram = eff_m * eff_n * dtsize * SPLIT_K / mib  # MB
    store_ms = store_c_dram / store_bw
    if SPLIT_K != 1:
        # Split-K additionally pays for zeroing the output (c.zero_()).
        store_ms += eff_m * eff_n * 2 / mib / store_bw

    total_time_ms = max(compute_ms, load_ms) + store_ms
    if debug:
        print(
            f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
            f"loading time: {load_ms}ms, store time: {store_ms}ms, "
            f"Activate CTAs: {active_cta_ratio*100}%"
        )
    return total_time_ms
134+
135+
136+
def early_config_prune(configs, named_args, **kwargs):
    """Discard autotuner configs that cannot run or are unlikely to be useful.

    Args:
        configs: iterable of config objects exposing ``kwargs`` (with the keys
            ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K`` and ``SPLIT_K``),
            ``num_warps`` and ``num_stages``.
        named_args: mapping of kernel arguments; ``named_args["A"]`` must
            provide ``element_size()`` and ``dtype``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        A list with the surviving configs.

    Note:
        On devices with compute-capability major version below 8 the kept
        config object is modified in place (``num_stages`` is set to 2).
    """
    device = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args["A"].element_size()
    dtype = named_args["A"].dtype

    # 1. make sure we have enough smem
    # The device limit does not depend on the config, so query it once
    # instead of once per candidate.
    max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"]
    fitting_configs = []
    for config in configs:
        kw = config.kwargs
        required_shared_memory = (kw["BLOCK_M"] + kw["BLOCK_N"]) * kw["BLOCK_K"] * config.num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            fitting_configs.append(config)
    configs = fitting_configs

    # Some dtypes do not allow atomic_add
    if dtype not in [torch.float16, torch.float32]:
        configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1]

    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        key = (kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"], kw["SPLIT_K"], config.num_warps)
        configs_map.setdefault(key, []).append((config, config.num_stages))

    pruned_configs = []
    for (BLOCK_M, BLOCK_N, BLOCK_K, _, num_warps), candidates in configs_map.items():
        if capability[0] >= 8:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8

            ldgsts_latency = 300  # Does this matter?
            optimal_num_stages = ldgsts_latency / mma_cycles

            # nearest stages, prefer large #stages: candidates below the
            # optimum get a +10 penalty so those at or above it rank first.
            nearest = heapq.nsmallest(
                2,
                candidates,
                key=lambda x: (
                    10 + abs(x[1] - optimal_num_stages)
                    if (x[1] - optimal_num_stages) < 0
                    else x[1] - optimal_num_stages
                ),
            )
            pruned_configs.extend(entry[0] for entry in nearest)
        else:  # Volta & Turing only supports num_stages <= 2
            # Keep the first config of the group and force its stage count.
            random_config = candidates[0][0]
            random_config.num_stages = 2
            pruned_configs.append(random_config)
    return pruned_configs

pyproject.toml

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,94 @@
11
[build-system]
2-
requires = [ "setuptools", "wheel" ]
2+
requires = ["setuptools >= 63.0.0"]
33
build-backend = "setuptools.build_meta"
44

5+
[project]
6+
name = "bitsandbytes"
7+
dynamic = ["version"]
8+
description = "k-bit optimizers and matrix multiplication routines."
9+
authors = [{name="Tim Dettmers", email="dettmers@cs.washington.edu"}]
10+
requires-python = ">=3.8"
11+
readme = "README.md"
12+
license = {file="LICENSE"}
13+
keywords = [
14+
"gpu",
15+
"optimizers",
16+
"optimization",
17+
"8-bit",
18+
"quantization",
19+
"compression"
20+
]
21+
classifiers = [
22+
"Development Status :: 4 - Beta",
23+
"License :: OSI Approved :: MIT License",
24+
"Environment :: GPU :: NVIDIA CUDA :: 11",
25+
"Environment :: GPU :: NVIDIA CUDA :: 12",
26+
"Intended Audience :: Developers",
27+
"Intended Audience :: Science/Research",
28+
"Operating System :: POSIX :: Linux",
29+
"Operating System :: MacOS",
30+
"Operating System :: Microsoft :: Windows",
31+
"Programming Language :: C++",
32+
"Programming Language :: Python :: Implementation :: CPython",
33+
"Programming Language :: Python :: 3.8",
34+
"Programming Language :: Python :: 3.9",
35+
"Programming Language :: Python :: 3.10",
36+
"Programming Language :: Python :: 3.11",
37+
"Programming Language :: Python :: 3.12",
38+
"Topic :: Scientific/Engineering :: Artificial Intelligence"
39+
]
40+
dependencies = [
41+
"torch>=1.11,!=1.12.0",
42+
"numpy>=1.17"
43+
]
44+
45+
[project.optional-dependencies]
46+
benchmark = ["pandas", "matplotlib"]
47+
docs = ["hf-doc-builder==0.5.0"]
48+
dev = [
49+
"bitsandbytes[test]",
50+
"build>=1.0.0,<2",
51+
"ruff==0.6.9",
52+
"pre-commit>=3.5.0,<4",
53+
"wheel>=0.42,<1"
54+
]
55+
test = [
56+
"einops~=0.6.0",
57+
"lion-pytorch==0.0.6",
58+
"pytest~=7.4",
59+
"scipy>=1.10.1,<2; python_version < '3.9'",
60+
"scipy>=1.11.4,<2; python_version >= '3.9'",
61+
"transformers>=4.30.1,<5"
62+
]
63+
triton = ["triton~=2.0.0; sys_platform=='linux' and platform_machine=='x86_64'"]
64+
65+
[project.urls]
66+
homepage = "https://github.com/TimDettmers/bitsandbytes"
67+
changelog = "https://github.com/TimDettmers/bitsandbytes/blob/main/CHANGELOG.md"
68+
docs = "https://huggingface.co/docs/bitsandbytes/main"
69+
issues = "https://github.com/TimDettmers/bitsandbytes/issues"
70+
71+
[tool.setuptools]
72+
package-data = { "*" = ["libbitsandbytes*.*"] }
73+
74+
[tool.setuptools.dynamic]
75+
version = {attr = "bitsandbytes.__version__"}
76+
77+
[tool.pytest.ini_options]
78+
addopts = "-rP"
79+
# ; --cov=bitsandbytes
80+
# ; # contexts: record which test ran which line; can be seen in html coverage report
81+
# ; --cov-context=test
82+
# ; --cov-report html
83+
log_cli = true
84+
log_cli_level = "INFO"
85+
log_file = "logs/pytest.log"
86+
markers = [
87+
"benchmark: mark test as a benchmark",
88+
"deprecated: mark test as covering a deprecated feature",
89+
"slow: mark test as slow",
90+
]
91+
592
[tool.ruff]
693
src = [
794
"bitsandbytes",

pytest.ini

Lines changed: 0 additions & 14 deletions
This file was deleted.

requirements-ci.txt

Lines changed: 0 additions & 6 deletions
This file was deleted.

requirements-dev.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)