Skip to content

Commit 0f6fe6b

Browse files
committed
Fixed default args
1 parent e33ba1c commit 0f6fe6b

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

bitsandbytes/_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _(
352352

353353
torch.library.define(
354354
"bitsandbytes::optimizer_update_32bit",
355-
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
355+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
356356
)
357357

358358

@@ -395,7 +395,7 @@ def _(
395395

396396
torch.library.define(
397397
"bitsandbytes::optimizer_update_8bit_blockwise",
398-
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
398+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
399399
)
400400

401401

@@ -417,8 +417,8 @@ def _(
417417
qmap2: Optional[torch.Tensor],
418418
absmax1: torch.Tensor,
419419
absmax2: Optional[torch.Tensor],
420-
weight_decay: float = 0.0,
421-
gnorm_scale: float = 1.0,
420+
weight_decay: float,
421+
gnorm_scale: float,
422422
skip_zeros=False,
423423
) -> None:
424424
torch._check(

bitsandbytes/backends/cuda/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,8 @@ def _optimizer_update_8bit_blockwise_impl(
686686
qmap2: Optional[torch.Tensor],
687687
absmax1: torch.Tensor,
688688
absmax2: Optional[torch.Tensor],
689-
weight_decay: float = 0.0,
690-
gnorm_scale: float = 1.0,
689+
weight_decay: float,
690+
gnorm_scale: float,
691691
skip_zeros=False,
692692
) -> None:
693693
# torch._check(

0 commit comments

Comments
 (0)