Skip to content

Commit 3dc9a43

Browse files
Merge pull request #3898 from R-N/lr-comma
Allow trailing comma in learning rate
2 parents 5612d03 + ef4c94e commit 3dc9a43

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

modules/textual_inversion/learn_schedule.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,37 @@
44
class LearnScheduleIterator:
55
def __init__(self, learn_rate, max_steps, cur_step=0):
66
"""
7-
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
7+
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
88
"""
99

1010
pairs = learn_rate.split(',')
1111
self.rates = []
1212
self.it = 0
1313
self.maxit = 0
14-
for i, pair in enumerate(pairs):
15-
tmp = pair.split(':')
16-
if len(tmp) == 2:
17-
step = int(tmp[1])
18-
if step > cur_step:
19-
self.rates.append((float(tmp[0]), min(step, max_steps)))
20-
self.maxit += 1
21-
if step > max_steps:
14+
try:
15+
for i, pair in enumerate(pairs):
16+
if not pair.strip():
17+
continue
18+
tmp = pair.split(':')
19+
if len(tmp) == 2:
20+
step = int(tmp[1])
21+
if step > cur_step:
22+
self.rates.append((float(tmp[0]), min(step, max_steps)))
23+
self.maxit += 1
24+
if step > max_steps:
25+
return
26+
elif step == -1:
27+
self.rates.append((float(tmp[0]), max_steps))
28+
self.maxit += 1
2229
return
23-
elif step == -1:
30+
else:
2431
self.rates.append((float(tmp[0]), max_steps))
2532
self.maxit += 1
2633
return
27-
else:
28-
self.rates.append((float(tmp[0]), max_steps))
29-
self.maxit += 1
30-
return
34+
assert self.rates
35+
except (ValueError, AssertionError):
36+
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
37+
3138

3239
def __iter__(self):
3340
return self

0 commit comments

Comments
 (0)