@@ -352,7 +352,7 @@ def _(
352352
353353torch .library .define (
354354 "bitsandbytes::optimizer_update_32bit" ,
355- "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()" ,
355+ "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False ) -> ()" ,
356356)
357357
358358
@@ -395,7 +395,7 @@ def _(
395395
396396torch .library .define (
397397 "bitsandbytes::optimizer_update_8bit_blockwise" ,
398- "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()" ,
398+ "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False ) -> ()" ,
399399)
400400
401401
@@ -417,8 +417,8 @@ def _(
417417 qmap2 : Optional [torch .Tensor ],
418418 absmax1 : torch .Tensor ,
419419 absmax2 : Optional [torch .Tensor ],
420- weight_decay : float = 0.0 ,
421- gnorm_scale : float = 1.0 ,
420+ weight_decay : float ,
421+ gnorm_scale : float ,
422422 skip_zeros = False ,
423423) -> None :
424424 torch ._check (
0 commit comments