Skip to content

Commit c9e1908

Browse files
committed
Merge branch 'main' into egor/8bit_opt2
2 parents cc68b22 + 14147f6 commit c9e1908

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

bitsandbytes/_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _(
352352

353353
torch.library.define(
354354
"bitsandbytes::optimizer_update_32bit",
355-
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, Tensor! unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
355+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
356356
)
357357

358358

@@ -395,7 +395,7 @@ def _(
395395

396396
torch.library.define(
397397
"bitsandbytes::optimizer_update_8bit_blockwise",
398-
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
398+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
399399
)
400400

401401

@@ -417,8 +417,8 @@ def _(
417417
qmap2: Optional[torch.Tensor],
418418
absmax1: torch.Tensor,
419419
absmax2: Optional[torch.Tensor],
420-
weight_decay: float = 0.0,
421-
gnorm_scale: float = 1.0,
420+
weight_decay: float,
421+
gnorm_scale: float,
422422
skip_zeros=False,
423423
) -> None:
424424
torch._check(

bitsandbytes/backends/cuda/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,8 @@ def _optimizer_update_8bit_blockwise_impl(
686686
qmap2: Optional[torch.Tensor],
687687
absmax1: torch.Tensor,
688688
absmax2: Optional[torch.Tensor],
689-
weight_decay: float = 0.0,
690-
gnorm_scale: float = 1.0,
689+
weight_decay: float,
690+
gnorm_scale: float,
691691
skip_zeros=False,
692692
) -> None:
693693
# torch._check(

tests/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
def get_available_devices(no_cpu=False):
2222
if "BNB_TEST_DEVICE" in os.environ:
2323
# If the environment variable is set, use it directly.
24-
return [d for d in os.environ["BNB_TEST_DEVICE"] if d.lower() != "cpu"]
24+
device = os.environ["BNB_TEST_DEVICE"]
25+
return [] if no_cpu and device == "cpu" else [device]
2526

2627
devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
2728

tests/test_optim.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def rm_path(path):
170170
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
171171
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
172172
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
173+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
173174
def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
174175
if optim_name.startswith("paged_") and sys.platform == "win32":
175176
pytest.skip("Paged optimizers can have issues on Windows.")
@@ -250,6 +251,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
250251
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
251252
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
252253
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
254+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
253255
def test_global_config(dim1, dim2, gtype, device):
254256
if dim1 == 1 and dim2 == 1:
255257
return
@@ -306,6 +308,7 @@ def test_global_config(dim1, dim2, gtype, device):
306308
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
307309
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
308310
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
311+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
309312
def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
310313
torch.set_printoptions(precision=6)
311314

0 commit comments

Comments (0)