
Commit 3f8da24

feat(lr): auto-adjust decay_steps instead of raising error
When decay_steps exceeds the decay phase (num_steps - warmup_steps) and decay_rate is not explicitly provided, automatically adjust decay_steps to a sensible default (capped at 100, or decay_total // 100 + 1) instead of raising ValueError. This makes the learning rate scheduler more user-friendly by gracefully handling misconfigured decay_steps values.

Changes:
- LearningRateExp: auto-adjust decay_steps when >= decay_total
- Update argcheck and training-advanced.md documentation
- Update pd/pt/tf test_lr.py to use auto-adjusted decay_steps
- Remove obsolete validation tests in test_learning_rate.py
- Fix tf test dtype: float32 -> float64
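The fallback rule described above can be summarized in a small standalone sketch. This is only an illustration of the rule stated in the message and implemented in the diff below; `decay_total` stands for `num_steps - warmup_steps`, and the example values are hypothetical, chosen to exercise both branches.

```python
# Sketch of the fallback rule (decay_total = num_steps - warmup_steps).
# Example values are hypothetical and only demonstrate both branches.
def default_decay_steps(decay_total: int) -> int:
    # Cap at 100 for long runs; short runs still get at least 1 step.
    return 100 if decay_total // 10 > 100 else decay_total // 100 + 1

assert default_decay_steps(50000) == 100  # 50000 // 10 = 5000 > 100 -> capped at 100
assert default_decay_steps(800) == 9      # 800 // 100 + 1 = 9
assert default_decay_steps(50) == 1       # tiny decay phase still yields a positive value
```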
1 parent 2e11654 commit 3f8da24

File tree

7 files changed: +31 -82 lines changed


deepmd/dpmodel/utils/learning_rate.py

Lines changed: 6 additions & 7 deletions
@@ -288,7 +288,6 @@ def __init__(
             If both stop_lr and stop_lr_ratio are provided, or neither is provided.
             If both warmup_steps and warmup_ratio are provided.
             If decay_steps is not positive.
-            If decay_steps is larger than the decay phase total steps when decay_rate is not provided.
         """
         super().__init__(
             start_lr=start_lr,
@@ -307,12 +306,12 @@ def __init__(
 
         if self.decay_steps <= 0:
             raise ValueError(f"decay_steps ({self.decay_steps}) must be positive.")
-        # Only validate decay_steps <= decay_total when computing decay_rate from start_lr/stop_lr
-        if decay_rate is None and self.decay_steps > decay_total:
-            raise ValueError(
-                f"decay_steps ({self.decay_steps}) must not exceed decay phase steps ({decay_total}) "
-                "when decay_rate is not explicitly provided."
-            )
+
+        # Auto-adjust decay_steps if it exceeds decay_total and decay_rate is not provided
+        if decay_rate is None and self.decay_steps >= decay_total:
+            # Compute sensible default: cap at 100, but ensure at least 1 for small decay_total
+            default_ds = 100 if decay_total // 10 > 100 else decay_total // 100 + 1
+            self.decay_steps = default_ds
 
         # Avoid log(0) issues by clamping stop_lr for computation
         clamped_stop_lr = max(self.stop_lr, 1e-10)

deepmd/utils/argcheck.py

Lines changed: 4 additions & 1 deletion
@@ -2669,7 +2669,10 @@ def learning_rate_exp() -> list[Argument]:
         "Mutually exclusive with stop_lr_ratio."
     )
     doc_decay_steps = (
-        "The learning rate is decaying every this number of training steps."
+        "The learning rate is decaying every this number of training steps. "
+        "If decay_steps exceeds the decay phase steps (num_steps - warmup_steps) "
+        "and decay_rate is not provided, it will be automatically adjusted to a "
+        "sensible default value."
     )
     doc_decay_rate = (
         "The decay rate for the learning rate. "

doc/train/training-advanced.md

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ The {ref}`learning_rate <learning_rate>` section for exponential decay in `input
 
 **Additional parameters for `exp` type only:**
 
-- {ref}`decay_steps <learning_rate[exp]/decay_steps>` specifies the interval (in training steps) at which the learning rate is decayed. The learning rate is updated every {ref}`decay_steps <learning_rate[exp]/decay_steps>` steps during the decay phase.
+- {ref}`decay_steps <learning_rate[exp]/decay_steps>` specifies the interval (in training steps) at which the learning rate is decayed. The learning rate is updated every {ref}`decay_steps <learning_rate[exp]/decay_steps>` steps during the decay phase. If `decay_steps` exceeds the decay phase steps (num_steps - warmup_steps) and `decay_rate` is not explicitly provided, it will be automatically adjusted to a sensible default value.
 - {ref}`smooth <learning_rate[exp]/smooth>` (optional, default: `false`) controls the decay behavior. When set to `false`, the learning rate decays in a stepped manner (updated every `decay_steps` steps). When set to `true`, the learning rate decays smoothly at every step.
 
 **Learning rate formula for `exp` type:**
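A minimal usage sketch of the documented behavior; the import path is assumed from the module touched in this commit (deepmd/dpmodel/utils/learning_rate.py), and the numbers mirror the boundary case removed from the tests below.

```python
from deepmd.dpmodel.utils.learning_rate import LearningRateExp  # assumed import path

# decay_steps (600) exceeds the decay phase (num_steps = 500, no warmup) and
# decay_rate is not provided, so decay_steps is auto-adjusted rather than
# raising ValueError as before.
lr = LearningRateExp(start_lr=1e-3, stop_lr=1e-5, num_steps=500, decay_steps=600)
print(lr.decay_steps)  # expected: 500 // 100 + 1 == 6 under the new rule
print(lr.value(0))     # learning rate at step 0 (start of the decay phase)
```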

source/tests/pd/test_lr.py

Lines changed: 9 additions & 13 deletions
@@ -18,7 +18,7 @@ class TestLearningRate(unittest.TestCase):
     def setUp(self):
         self.start_lr = 0.001
         self.stop_lr = 3.51e-8
-        # decay_steps must not exceed num_steps
+        # decay_steps will be auto-adjusted if >= num_steps
         self.decay_steps = np.arange(400, 501, 100)
         self.num_steps = np.arange(500, 1600, 500)
 
@@ -72,44 +72,40 @@ def decay_rate_pd(self):
             num_steps=self.stop_step,
         )
 
-        default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1
-        # Use local variable to avoid modifying instance state
-        decay_step_for_rate = self.decay_step
-        if decay_step_for_rate >= self.stop_step:
-            decay_step_for_rate = default_ds
+        # Use the auto-adjusted decay_steps from my_lr for consistency
+        actual_decay_steps = my_lr.decay_steps
         decay_rate = np.exp(
-            np.log(self.stop_lr / self.start_lr)
-            / (self.stop_step / decay_step_for_rate)
+            np.log(self.stop_lr / self.start_lr) / (self.stop_step / actual_decay_steps)
         )
         my_lr_decay = LearningRateExp(
             start_lr=self.start_lr,
             stop_lr=1e-10,
-            decay_steps=self.decay_step,
+            decay_steps=actual_decay_steps,
             num_steps=self.stop_step,
             decay_rate=decay_rate,
         )
         min_lr = 1e-5
         my_lr_decay_trunc = LearningRateExp(
             start_lr=self.start_lr,
             stop_lr=min_lr,
-            decay_steps=self.decay_step,
+            decay_steps=actual_decay_steps,
             num_steps=self.stop_step,
             decay_rate=decay_rate,
         )
         my_vals = [
             my_lr.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         my_vals_decay = [
             my_lr_decay.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         my_vals_decay_trunc = [
             my_lr_decay_trunc.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         self.assertTrue(np.allclose(my_vals_decay, my_vals))
         self.assertTrue(

source/tests/pt/test_lr.py

Lines changed: 9 additions & 13 deletions
@@ -19,7 +19,7 @@ class TestLearningRate(unittest.TestCase):
     def setUp(self) -> None:
         self.start_lr = 0.001
         self.stop_lr = 3.51e-8
-        # decay_steps must not exceed num_steps
+        # decay_steps will be auto-adjusted if >= num_steps
        self.decay_steps = np.arange(400, 501, 100)
         self.num_steps = np.arange(500, 1600, 500)
 
@@ -73,44 +73,40 @@ def decay_rate_pt(self) -> None:
             num_steps=self.stop_step,
         )
 
-        default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1
-        # Use local variable to avoid modifying instance state
-        decay_step_for_rate = self.decay_step
-        if decay_step_for_rate >= self.stop_step:
-            decay_step_for_rate = default_ds
+        # Use the auto-adjusted decay_steps from my_lr for consistency
+        actual_decay_steps = my_lr.decay_steps
         decay_rate = np.exp(
-            np.log(self.stop_lr / self.start_lr)
-            / (self.stop_step / decay_step_for_rate)
+            np.log(self.stop_lr / self.start_lr) / (self.stop_step / actual_decay_steps)
         )
         my_lr_decay = LearningRateExp(
             start_lr=self.start_lr,
             stop_lr=1e-10,
-            decay_steps=self.decay_step,
+            decay_steps=actual_decay_steps,
             num_steps=self.stop_step,
             decay_rate=decay_rate,
         )
         min_lr = 1e-5
         my_lr_decay_trunc = LearningRateExp(
             start_lr=self.start_lr,
             stop_lr=min_lr,
-            decay_steps=self.decay_step,
+            decay_steps=actual_decay_steps,
             num_steps=self.stop_step,
             decay_rate=decay_rate,
         )
         my_vals = [
             my_lr.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         my_vals_decay = [
             my_lr_decay.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         my_vals_decay_trunc = [
             my_lr_decay_trunc.value(step_id)
             for step_id in range(self.stop_step)
-            if step_id % self.decay_step != 0
+            if step_id % actual_decay_steps != 0
         ]
         self.assertTrue(np.allclose(my_vals_decay, my_vals))
         self.assertTrue(

source/tests/tf/test_lr.py

Lines changed: 2 additions & 8 deletions
@@ -23,12 +23,6 @@
 class TestLearningRateScheduleValidation(unittest.TestCase):
     """Test TF wrapper validation and error handling."""
 
-    def test_missing_start_lr(self) -> None:
-        """Test that missing start_lr raises ValueError."""
-        with self.assertRaises(ValueError) as cm:
-            LearningRateSchedule({"type": "exp", "stop_lr": 1e-5})
-        self.assertIn("start_lr", str(cm.exception))
-
     def test_value_before_build(self) -> None:
         """Test that calling value() before build() raises RuntimeError."""
         lr_schedule = LearningRateSchedule({"start_lr": 1e-3})
@@ -48,13 +42,13 @@ class TestLearningRateScheduleBuild(unittest.TestCase):
     """Test TF tensor building and integration."""
 
     def test_build_returns_tensor(self) -> None:
-        """Test that build() returns a float32 TF tensor."""
+        """Test that build() returns a float64 TF tensor."""
         lr_schedule = LearningRateSchedule({"start_lr": 1e-3, "stop_lr": 1e-5})
         global_step = tf.constant(0, dtype=tf.int64)
         lr_tensor = lr_schedule.build(global_step, num_steps=10000)
 
         self.assertIsInstance(lr_tensor, tf.Tensor)
-        self.assertEqual(lr_tensor.dtype, tf.float32)
+        self.assertEqual(lr_tensor.dtype, tf.float64)
 
     def test_default_type_exp(self) -> None:
         """Test that default type is 'exp' when not specified."""

source/tests/universal/dpmodel/utils/test_learning_rate.py

Lines changed: 0 additions & 39 deletions
@@ -199,42 +199,3 @@ def test_cosine_beyond_num_steps(self) -> None:
             num_steps=10000,
         )
         np.testing.assert_allclose(lr.value(20000), 1e-5, rtol=1e-10)
-
-
-class TestLearningRateValidation(unittest.TestCase):
-    """Test learning rate parameter validation."""
-
-    def test_decay_steps_exceeds_decay_total_without_warmup(self) -> None:
-        """Test that decay_steps > num_steps raises ValueError."""
-        with self.assertRaises(ValueError) as cm:
-            LearningRateExp(
-                start_lr=1e-3,
-                stop_lr=1e-5,
-                num_steps=500,
-                decay_steps=600,
-            )
-        self.assertIn("decay_steps", str(cm.exception))
-        self.assertIn("exceed", str(cm.exception))
-
-    def test_decay_steps_exceeds_decay_total_with_warmup(self) -> None:
-        """Test that decay_steps > (num_steps - warmup_steps) raises ValueError."""
-        with self.assertRaises(ValueError) as cm:
-            LearningRateExp(
-                start_lr=1e-3,
-                stop_lr=1e-5,
-                num_steps=1000,
-                decay_steps=900,
-                warmup_steps=200,  # decay_total = 800
-            )
-        self.assertIn("decay_steps", str(cm.exception))
-
-    def test_decay_steps_equals_decay_total_allowed(self) -> None:
-        """Test that decay_steps == decay_total is allowed (boundary case)."""
-        # Should not raise
-        lr = LearningRateExp(
-            start_lr=1e-3,
-            stop_lr=1e-5,
-            num_steps=500,
-            decay_steps=500,
-        )
-        self.assertEqual(lr.decay_steps, 500)
