Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 8dda5e7

Browse files
committed
Force conversion of LR variables to float due to type coercion from manager serialization (#384)
1 parent 47d9472 commit 8dda5e7

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/sparseml/optim/learning_rate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ def validate_lr_info(self):
156156
else:
157157
raise ValueError("unknown lr_class given of {}".format(self._lr_class))
158158

159+
if isinstance(self._init_lr, str):
160+
self._init_lr = float(self._init_lr)
161+
159162
if self._init_lr <= 0.0:
160163
raise ValueError("init_lr must be greater than 0")
161164

src/sparseml/pytorch/optim/modifier_lr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ def validate(self):
414414
if self.lr_func not in lr_funcs:
415415
raise ValueError(f"lr_func must be one of {lr_funcs}")
416416

417+
if isinstance(self.init_lr, str):
418+
self.init_lr = float(self.init_lr)
419+
417420
if (
418421
(not self.init_lr and self.init_lr != 0)
419422
or self.init_lr < 0.0
@@ -423,6 +426,9 @@ def validate(self):
423426
f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
424427
)
425428

429+
if isinstance(self.final_lr, str):
430+
self.final_lr = float(self.final_lr)
431+
426432
if (
427433
(not self.final_lr and self.final_lr != 0)
428434
or self.final_lr < 0.0

0 commit comments

Comments (0)