
Commit 622c0ab

Merge branch 'main' into main

2 parents 38054f6 + ff389db

File tree: 18 files changed, +554 -202 lines

.github/FUNDING.yml (1 addition, 0 deletions)

@@ -0,0 +1 @@
+open_collective: bitsandbytes

.github/scripts/build-cuda.sh (3 additions, 3 deletions)

@@ -15,10 +15,10 @@ elif [ "${build_arch}" = "aarch64" ]; then
     [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
 else
     # By default, target Maxwell through Hopper.
-    build_capability="50;52;60;61;70;75;80;86;89;90"
+    build_capability="50;60;70;75;80;86;89;90"

-    # CUDA 12.8+: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
-    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;86;89;90;100;120"
+    # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum
+    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120"
 fi

 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja

README.md (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ bitsandbytes has the following minimum requirements for all platforms:
 #### Accelerator support:

 <small>Note: this table reflects the status of the current development branch. For the latest stable release, see the
-[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
+[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support).
 </small>

 ##### Legend:

bitsandbytes/__init__.py (1 addition, 2 deletions)

@@ -38,7 +38,6 @@
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops

-
 if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
     # In case not automatically imported
     import habana_frameworks.torch
@@ -76,4 +75,4 @@ def _import_backends():
     "optim.optimizer.MockArgs": False,
 }

-__version__ = "0.47.0.dev0"
+__version__ = "0.48.0.dev0"

bitsandbytes/_ops.py (104 additions, 0 deletions)

@@ -327,3 +327,107 @@ def _(
     )
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_32bit",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_32bit")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
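
Not part of the diff, but for orientation: a minimal sketch of invoking the newly defined 32-bit optimizer op through torch.ops, with the arguments laid out per the schema above. It assumes a CUDA build of bitsandbytes so that the kernel registered in bitsandbytes/backends/cuda/ops.py (below) is available; the tensor sizes and hyperparameters are illustrative only.

    import torch
    import bitsandbytes  # noqa: F401  -- importing registers the bitsandbytes::* ops

    p = torch.randn(1024, device="cuda", dtype=torch.float32)  # parameters (updated in place)
    g = torch.randn_like(p)                                    # gradients
    state1 = torch.zeros_like(p)                               # Adam first moment
    state2 = torch.zeros_like(p)                               # Adam second moment

    # Argument order follows the schema: optimizer_name, g, p, state1, state2, unorm_vec,
    # max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale.
    torch.ops.bitsandbytes.optimizer_update_32bit(
        "adam", g, p, state1, state2, None,
        0.0, 0.0,          # max_unorm, param_norm (no update clipping)
        0.9, 0.999, 0.0,   # beta1, beta2, beta3
        0.0, 1e-8,         # alpha, eps
        0.0,               # weight_decay
        1,                 # step
        1e-3,              # lr
        1.0,               # gnorm_scale
    )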

bitsandbytes/backends/cuda/ops.py (226 additions, 0 deletions)

@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
         ct.c_int32(blocksize),
         stream,
     )
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+    "adam": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum32bit_grad_32,
+        lib.cmomentum32bit_grad_16,
+    ),
+    "rmsprop": (
+        lib.crmsprop32bit_grad_32,
+        lib.crmsprop32bit_grad_16,
+    ),
+    "lion": (
+        lib.clion32bit_grad_fp32,
+        lib.clion32bit_grad_fp16,
+        lib.clion32bit_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad32bit_grad_32,
+        lib.cadagrad32bit_grad_16,
+    ),
+    "lamb": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
+}
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_32bit_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    optim_fns = str2optimizer32bit.get(optimizer_name, None)
+    if optim_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+    if g.dtype == torch.float32:
+        optim_func = optim_fns[0]
+    elif g.dtype == torch.float16:
+        optim_func = optim_fns[1]
+    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+        optim_func = optim_fns[2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
