Commit b1f80b8

[CUDA] Branchless NF4/FP4 kDequantizeBlockwise kernel for faster dequantization (#1746)
* Added branchless LUT-based dequantization for FP4 and NF4
* Added extra command line options to control reproducibility
* Restore FP4 quantization/dequantization order
1 parent c9bce2b commit b1f80b8

File tree

2 files changed: +106 −116 lines


benchmarking/inference_benchmark.py

Lines changed: 71 additions & 30 deletions
@@ -21,6 +21,9 @@
   --batches BATCHES [BATCHES ...]
   --input-length INPUT_LENGTH
   --out-dir OUT_DIR
+  --iterations ITERATIONS
+  --warmup-runs WARMUP_RUNS
+  --output-length OUTPUT_LENGTH
 """
 
 import argparse
@@ -30,6 +33,9 @@
 from optimum_benchmark.logging_utils import setup_logging
 import torch
 
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
 BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
 
 WEIGHTS_CONFIGS = {
@@ -73,9 +79,8 @@
     },
 }
 
-if __name__ == "__main__":
-    setup_logging(level="INFO")
 
+def parse_args():
     parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
 
     parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
@@ -98,37 +103,73 @@
 
     parser.add_argument("--out-dir", type=str, default="reports")
 
-    args = parser.parse_args()
+    parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run")
+    parser.add_argument(
+        "--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement"
+    )
+    parser.add_argument(
+        "--output-length",
+        type=int,
+        default=64,
+        help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.",
+    )
+
+    return parser.parse_args()
+
+
+def run_benchmark(args, config, batch_size):
+    launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn")
+    scenario_config = InferenceConfig(
+        latency=True,
+        memory=True,
+        input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
+        iterations=args.iterations,
+        warmup_runs=args.warmup_runs,
+        # set duration to 0 to disable the duration-based stopping criterion
+        # this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks
+        duration=0,
+        # for consistent results, set a fixed min and max for output tokens
+        generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
+        forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
+    )
+
+    backend_config = PyTorchConfig(
+        device="cuda",
+        device_ids="0",
+        device_map="auto",
+        no_weights=False,
+        model=args.model_id,
+        **WEIGHTS_CONFIGS[config],
+    )
+
+    test_name = (
+        f"benchmark-{config}"
+        f"-bsz-{batch_size}"
+        f"-isz-{args.input_length}"
+        f"-osz-{args.output_length}"
+        f"-iter-{args.iterations}"
+        f"-wrmup-{args.warmup_runs}"
+    )
+    benchmark_config = BenchmarkConfig(
+        name=test_name,
+        scenario=scenario_config,
+        launcher=launcher_config,
+        backend=backend_config,
+    )
+
+    out_path = out_dir / (test_name + ".json")
+    print(f"[{test_name}] Starting:")
+    benchmark_report = Benchmark.launch(benchmark_config)
+    benchmark_report.save_json(out_path)
+
+
+if __name__ == "__main__":
+    setup_logging(level="INFO")
+    args = parse_args()
 
     out_dir = Path(args.out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
 
     for batch_size in args.batches:
-        print(f"Benchmarking batch size: {batch_size}")
         for config in args.configs:
-            launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
-            scenario_config = InferenceConfig(
-                latency=True,
-                memory=True,
-                input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
-            )
-            backend_config = PyTorchConfig(
-                device="cuda",
-                device_ids="0",
-                device_map="auto",
-                no_weights=False,
-                model=args.model_id,
-                **WEIGHTS_CONFIGS[config],
-            )
-            benchmark_config = BenchmarkConfig(
-                name=f"benchmark-{config}-bsz{batch_size}",
-                scenario=scenario_config,
-                launcher=launcher_config,
-                backend=backend_config,
-            )
-
-            out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
-
-            benchmark_report = Benchmark.launch(benchmark_config)
-            benchmark_report.log()
-            benchmark_report.save_json(out_path)
+            run_benchmark(args, config, batch_size)

csrc/kernels.cu

Lines changed: 35 additions & 86 deletions
@@ -21,23 +21,34 @@
 #define NUM 4
 #define NUM_BLOCK 4096
 
-__device__ static float nf4_data[16] = {
-    -1.0,
-    -0.6961928009986877,
-    -0.5250730514526367,
-    -0.39491748809814453,
-    -0.28444138169288635,
-    -0.18477343022823334,
-    -0.09105003625154495,
-    0.0,
-    0.07958029955625534,
-    0.16093020141124725,
-    0.24611230194568634,
-    0.33791524171829224,
-    0.44070982933044434,
-    0.5626170039176941,
-    0.7229568362236023,
-    1.0
+__device__ static float fp4_dequantization_lut[8] = {
+    0.0f,            // 0b000
+    0.005208333333f, // 0b001
+    0.66666667f,     // 0b010
+    1.0f,            // 0b011
+    0.33333333f,     // 0b100
+    0.5f,            // 0b101
+    0.16666667f,     // 0b110
+    0.25f            // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+    -1.0f,                 // 0b0000
+    -0.6961928009986877f,  // 0b0001
+    -0.5250730514526367f,  // 0b0010
+    -0.39491748809814453f, // 0b0011
+    -0.28444138169288635f, // 0b0100
+    -0.18477343022823334f, // 0b0101
+    -0.09105003625154495f, // 0b0110
+    0.0f,                  // 0b0111
+    0.07958029955625534f,  // 0b1000
+    0.16093020141124725f,  // 0b1001
+    0.24611230194568634f,  // 0b1010
+    0.33791524171829224f,  // 0b1011
+    0.44070982933044434f,  // 0b1100
+    0.5626170039176941f,   // 0b1101
+    0.7229568362236023f,   // 0b1110
+    1.0f                   // 0b1111
 };
 
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
@@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) {
     return __int_as_float(old);
 }
 
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
-    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-    if ((val & 0b0100) == 4)                            // 0
-        if ((val & 0b0010) == 2)                        // 01
-            if ((val & 0b0001) == 1)                    // 111
-                return 0.25000000f * absmax * sign;     // 1111
-            else
-                return 0.16666667f * absmax * sign;     // 1110
-        else if ((val & 0b0001) == 1)                   // 110
-            return 0.50000000f * absmax * sign;         // 1101
-        else
-            return 0.33333333f * absmax * sign;         // 1100
-    else if ((val & 0b0010) == 2)                       // 10
-        if ((val & 0b0001) == 1)                        // 101
-            return 1.00000000f * absmax * sign;         // 1011
-        else
-            return 0.66666667f * absmax * sign;         // 1010
-    else if ((val & 0b0001) == 1)                       // 100
-        return 5.208333333e-03f * absmax * sign;        // 1001
-    else
-        return 0.00000000f * absmax * sign;             // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+    return fp4_dequantization_lut[val & 0b111] * sign;
 }
 
 __device__ unsigned char dQuantizeFP4(float x) {
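Why branchless helps: the old tree reached each nibble's value through up to four data-dependent comparisons, so threads in the same warp holding different 4-bit codes take different paths and serialize. The replacement computes the sign without a comparison (1.0f - 2 * ((val & 0b1000) >> 3) evaluates to +1.0f when bit 3 is clear and -1.0f when it is set) and uses bits 0..2 to index a magnitude LUT, so every thread executes the identical instruction sequence. Below is a minimal host-side sketch, not part of the commit, that checks the branchless form against the old tree's values for all 16 codes; dequant_reference is a hypothetical collapse of the old tree into a switch, and all names here are illustrative.

#include <cstdio>

// Host-side copy of the same magnitudes as fp4_dequantization_lut.
static const float fp4_lut[8] = {0.0f,        0.005208333333f, 0.66666667f, 1.0f,
                                 0.33333333f, 0.5f,            0.16666667f, 0.25f};

// Branchless form, as in the new kernel: bit 3 selects the sign,
// bits 0..2 index the magnitude table.
float dequant_branchless(unsigned char val) {
    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
    return fp4_lut[val & 0b111] * sign;
}

// Hypothetical reference: the old branchy tree collapsed into a switch.
float dequant_reference(unsigned char val) {
    float sign = (val & 0b1000) ? -1.0f : 1.0f;
    switch (val & 0b0111) {
    case 0b000: return 0.00000000f * sign;
    case 0b001: return 5.208333333e-03f * sign;
    case 0b010: return 0.66666667f * sign;
    case 0b011: return 1.00000000f * sign;
    case 0b100: return 0.33333333f * sign;
    case 0b101: return 0.50000000f * sign;
    case 0b110: return 0.16666667f * sign;
    default:    return 0.25000000f * sign;
    }
}

int main() {
    // Compare both forms over every 4-bit code.
    for (unsigned int v = 0; v < 16; ++v)
        if (dequant_branchless((unsigned char)v) != dequant_reference((unsigned char)v))
            printf("mismatch at code %u\n", v);
    printf("done\n");
    return 0;
}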
@@ -118,51 +111,7 @@ __device__ unsigned char dQuantizeFP4(float x) {
     return 0b0000 + sign;
 }
 
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
-
-    // the values for this tree was generated by test_normal_map_tree
-    // in the file tests/test_functional.py
-    if ((val & 0b1000) == 8)
-        if ((val & 0b0100) == 4)             // 1
-            if ((val & 0b0010) == 2)         // 11
-                if ((val & 0b0001) == 1)     // 111
-                    return 1.0f;
-                else
-                    return 0.7229568362236023f;
-            else if ((val & 0b0001) == 1)    // 110
-                return 0.5626170039176941f;
-            else
-                return 0.44070982933044434f;
-        else if ((val & 0b0010) == 2)        // 10
-            if ((val & 0b0001) == 1)         // 101
-                return 0.33791524171829224f;
-            else
-                return 0.24611230194568634f;
-        else if ((val & 0b0001) == 1)        // 100
-            return 0.16093020141124725f;
-        else
-            return 0.07958029955625534f;
-
-    else if ((val & 0b0100) == 4)            // 0
-        if ((val & 0b0010) == 2)             // 01
-            if ((val & 0b0001) == 1)         // 011
-                return 0.0f;
-            else
-                return -0.09105003625154495f;
-        else if ((val & 0b0001) == 1)        // 010
-            return -0.18477343022823334f;
-        else
-            return -0.28444138169288635f;
-    else if ((val & 0b0010) == 2)            // 00
-        if ((val & 0b0001) == 1)             // 001
-            return -0.39491748809814453f;
-        else
-            return -0.5250730514526367f;
-    else if ((val & 0b0001) == 1)            // 000
-        return -0.6961928009986877f;
-    else
-        return -1.0f;
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
 
 __device__ unsigned char dQuantizeNF4(float x) {
 
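NF4 has no separate sign bit, so the whole 4-bit code indexes the 16-entry table directly, collapsing the 45-line tree into a single lookup. A hypothetical stand-alone toy kernel (not the commit's kDequantizeBlockwise; kNF4Lut, toy_dequantize_nf4, and the packing assumptions are illustrative names for this sketch) showing the resulting uniform access pattern, two output values per packed byte:

// Per-code NF4 values, identical to nf4_dequantization_lut above.
__device__ static const float kNF4Lut[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f};

// Each input byte packs two 4-bit codes; absmax is the per-block scale.
// Every thread runs the same load/lookup/multiply sequence, so a warp
// never diverges on the codes it happens to hold.
__global__ void toy_dequantize_nf4(const unsigned char* packed, float* out, float absmax, int n_bytes) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_bytes) {
        unsigned char q = packed[i];
        out[i * 2] = kNF4Lut[q >> 4] * absmax;       // high nibble first
        out[i * 2 + 1] = kNF4Lut[q & 0x0F] * absmax; // then low nibble
    }
}

Launched as, e.g., toy_dequantize_nf4<<<blocks, threads>>>(d_packed, d_out, absmax, n_bytes), assuming d_out holds 2 * n_bytes floats; the high-nibble-first ordering matches the kernel's call sites below.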

@@ -510,8 +459,8 @@ __global__ void
     case FP4:
 #pragma unroll NUM_PER_TH
         for (int j = 0; j < NUM_PER_TH; j++) {
-            vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
-            vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+            vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+            vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
         }
         break;
     case NF4:
@@ -2352,7 +2301,7 @@ __global__ void kgemm_4bit_inference(
 
 #pragma unroll 16
     for (int i = 0; i < 16; i++)
-        quant_map[i] = nf4_data[i];
+        quant_map[i] = nf4_dequantization_lut[i];
     //__shared__ T quant_map[16*160];
 
     T local_A[2];
