Skip to content

Commit e33ba1c

Browse files
committed
Added mutated args to the schema
1 parent 24d9139 commit e33ba1c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bitsandbytes/_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _(
352352

353353
torch.library.define(
354354
"bitsandbytes::optimizer_update_32bit",
355-
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, Tensor! unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
355+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
356356
)
357357

358358

@@ -395,7 +395,7 @@ def _(
395395

396396
torch.library.define(
397397
"bitsandbytes::optimizer_update_8bit_blockwise",
398-
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
398+
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
399399
)
400400

401401

0 commit comments

Comments
 (0)