Skip to content

Commit a5f3adb

Browse files
committed
Allow trailing comma in learning rate
1 parent 35c45df commit a5f3adb

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

modules/textual_inversion/learn_schedule.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,30 @@ def __init__(self, learn_rate, max_steps, cur_step=0):
1111
self.rates = []
1212
self.it = 0
1313
self.maxit = 0
14-
for i, pair in enumerate(pairs):
15-
tmp = pair.split(':')
16-
if len(tmp) == 2:
17-
step = int(tmp[1])
18-
if step > cur_step:
19-
self.rates.append((float(tmp[0]), min(step, max_steps)))
20-
self.maxit += 1
21-
if step > max_steps:
14+
try:
15+
for i, pair in enumerate(pairs):
16+
if not pair.strip():
17+
continue
18+
tmp = pair.split(':')
19+
if len(tmp) == 2:
20+
step = int(tmp[1])
21+
if step > cur_step:
22+
self.rates.append((float(tmp[0]), min(step, max_steps)))
23+
self.maxit += 1
24+
if step > max_steps:
25+
return
26+
elif step == -1:
27+
self.rates.append((float(tmp[0]), max_steps))
28+
self.maxit += 1
2229
return
23-
elif step == -1:
30+
else:
2431
self.rates.append((float(tmp[0]), max_steps))
2532
self.maxit += 1
2633
return
27-
else:
28-
self.rates.append((float(tmp[0]), max_steps))
29-
self.maxit += 1
30-
return
34+
assert self.rates
35+
except (ValueError, AssertionError):
36+
raise Exception("Invalid learning rate schedule")
37+
3138

3239
def __iter__(self):
3340
return self

0 commit comments

Comments
 (0)