Commit 89cdf56

refactor: unify learning rate schedulers with array API
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends
1 parent 2a9667e commit 89cdf56
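
Below is a minimal usage sketch of the unified scheduler API described above, assuming LearningRateExp is importable from the refactored module deepmd/dpmodel/utils/learning_rate.py shown in the diff. The values in the comments follow from the warmup and decay logic there; the surrounding training-input plumbing (argcheck keys, backend wiring) is not shown.

import numpy as np

from deepmd.dpmodel.utils.learning_rate import LearningRateExp

# stop_ratio is an alternative to an explicit stop_lr; warmup_ratio is an
# alternative to warmup_steps (each pair is mutually exclusive).
lr = LearningRateExp(
    start_lr=1e-3,
    stop_ratio=1e-2,        # stop_lr = 1e-3 * 1e-2 = 1e-5
    stop_steps=100000,
    decay_steps=5000,
    warmup_ratio=0.01,      # warmup_steps = int(0.01 * 100000) = 1000
    warmup_start_factor=0.0,
)

print(lr.value(0))          # 0.0: warmup starts at warmup_start_factor * start_lr
print(lr.value(500))        # 5e-4: halfway through the linear warmup
print(lr.value(1000))       # 1e-3: warmup ends, exponential decay begins
print(lr.value(np.arange(100000)))  # array input works too (backend-agnostic ops)

The same keyword arguments (stop_ratio, warmup_steps/warmup_ratio, warmup_start_factor) apply to LearningRateCosine, since both subclasses forward them to BaseLR.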

File tree

21 files changed (+1033, -300 lines)

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 262 additions & 51 deletions
@@ -29,86 +29,251 @@ def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
         return super().__new__(cls)
 
     def __init__(
-        self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any
+        self,
+        start_lr: float,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        stop_steps: int = 100000,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
+        **kwargs: Any,
     ) -> None:
         """
-        Base class for learning rate schedules.
+        Base class for learning rate schedules with warmup support.
 
         Parameters
         ----------
-        start_lr
-            The initial learning rate.
-        stop_lr
-            The final learning rate.
-        stop_steps
-            The total training steps for learning rate scheduler.
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
+            The final learning rate at the end of the training.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_ratio.
+            Mutually exclusive with stop_lr.
+            One of stop_lr or stop_ratio must be provided.
+        stop_steps : int
+            The total training steps (including warmup).
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0 (no warmup).
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            warmup_steps = int(warmup_ratio * stop_steps).
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            The warmup learning rate starts from warmup_start_factor * start_lr.
+            Default is 0.0.
         """
+        # === Step 1. Compute stop_lr from stop_ratio if needed ===
+        # Mutual exclusion validated in argcheck.py
+        if stop_ratio is not None:
+            self.stop_lr = start_lr * stop_ratio
+        else:
+            self.stop_lr = stop_lr  # type: ignore[assignment]
+
+        # === Step 2. Compute warmup_steps from warmup_ratio if needed ===
+        # Mutual exclusion validated in argcheck.py
+        if warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * stop_steps)
+        else:
+            self.warmup_steps = warmup_steps
+
+        # === Step 3. Validate step ranges (runtime check) ===
+        if stop_steps <= 0:
+            raise ValueError("stop_steps must be positive")
+        if self.warmup_steps < 0:
+            raise ValueError("warmup_steps must be non-negative")
+        if self.warmup_steps >= stop_steps:
+            raise ValueError("warmup_steps must be smaller than stop_steps")
+
+        # === Step 4. Compute warmup_start_lr ===
+        self.warmup_start_lr = warmup_start_factor * start_lr
+
+        # === Step 5. Store core parameters ===
         self.start_lr = start_lr
-        self.stop_lr = stop_lr
         self.stop_steps = stop_steps
+        # Decay phase covers (stop_steps - warmup_steps) steps
+        self.decay_stop_steps = stop_steps - self.warmup_steps
 
     @abstractmethod
-    def value(self, step: int | Array) -> Array:
-        """Get the learning rate at the given step."""
-        # in optax, step will be a jnp.ndarray passed in JIT mode
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the decayed learning rate at the given step (after warmup).
+
+        This method should implement the actual decay logic (exp, cosine, etc.)
+        without considering warmup.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+            For example, if warmup_steps=100 and total_step=150, this method
+            will be called with step=50.
+
+        Returns
+        -------
+        Array
+            The decayed learning rate (absolute value, not factor).
+        """
         pass
 
+    def value(self, step: int | Array) -> Array | float:
+        """
+        Get the learning rate at the given step, including warmup.
+
+        Parameters
+        ----------
+        step : int or Array
+            The absolute step index from the start of training.
+
+        Returns
+        -------
+        Array
+            The learning rate at the given step.
+        """
+        is_scalar = isinstance(step, (int, float))
+        if not array_api_compat.is_array_api_obj(step):
+            step = np.asarray(step)
+        xp = array_api_compat.array_namespace(step)
+
+        # === Step 1. Handle no-warmup case directly ===
+        if self.warmup_steps == 0:
+            lr = self._decay_value(xp.astype(step, xp.float64))
+        else:
+            # === Step 2. Warmup phase ===
+            # Linear warmup from warmup_start_lr to start_lr
+            warmup_progress = xp.astype(step, xp.float64) / self.warmup_steps
+            warmup_lr = (
+                self.warmup_start_lr
+                + (self.start_lr - self.warmup_start_lr) * warmup_progress
+            )
+
+            # === Step 3. Decay phase ===
+            # Call subclass decay logic for steps after warmup
+            decay_step = xp.maximum(
+                xp.astype(step, xp.float64) - self.warmup_steps, 0.0
+            )
+            decay_lr = self._decay_value(decay_step)
+
+            # === Step 4. Select warmup or decay based on step ===
+            lr = xp.where(step < self.warmup_steps, warmup_lr, decay_lr)
+
+        if is_scalar:
+            return float(lr)
+        return lr
+
 
 @BaseLR.register("exp")
 class LearningRateExp(BaseLR):
     def __init__(
         self,
         start_lr: float,
-        stop_lr: float,
-        decay_steps: int,
-        stop_steps: int,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        decay_steps: int = 5000,
+        stop_steps: int = 100000,
         decay_rate: float | None = None,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
         **kwargs: Any,
     ) -> None:
         """
-        Construct an exponential-decayed learning rate.
+        Construct an exponential-decayed learning rate with optional warmup.
 
         Parameters
         ----------
-        start_lr
-            The learning rate at the start of the training.
-        stop_lr
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
             The desired learning rate at the end of the training.
             When decay_rate is explicitly set, this value will serve as
-            the minimum learning rate during training. In other words,
-            if the learning rate decays below stop_lr, stop_lr will be applied instead.
-        decay_steps
+            the minimum learning rate during training.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr.
+            Mutually exclusive with stop_lr.
+        decay_steps : int
             The learning rate is decaying every this number of training steps.
-        stop_steps
-            The total training steps for learning rate scheduler.
-        decay_rate
+            Default is 5000.
+        stop_steps : int
+            The total training steps (including warmup).
+        decay_rate : float, optional
             The decay rate for the learning rate.
             If provided, the decay rate will be set instead of
             calculating it through interpolation between start_lr and stop_lr.
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0.
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            Default is 0.0.
+
+        Raises
+        ------
+        ValueError
+            If both stop_lr and stop_ratio are provided, or neither is provided.
+            If both warmup_steps and warmup_ratio are provided.
+            If decay_steps is larger than the decay phase total steps.
         """
-        super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
-        default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
+        super().__init__(
+            start_lr=start_lr,
+            stop_lr=stop_lr,
+            stop_ratio=stop_ratio,
+            stop_steps=stop_steps,
+            warmup_steps=warmup_steps,
+            warmup_ratio=warmup_ratio,
+            warmup_start_factor=warmup_start_factor,
+            **kwargs,
+        )
+        # === Step 5. Compute decay_rate for exp scheduler ===
+        # Use decay_stop_steps (stop_steps - warmup_steps) for decay calculation
+        decay_total = self.decay_stop_steps
         self.decay_steps = decay_steps
-        if self.decay_steps >= stop_steps:
-            self.decay_steps = default_ds
+
+        if self.decay_steps > decay_total:
+            raise ValueError(
+                f"decay_steps ({self.decay_steps}) must not exceed decay phase steps ({decay_total})."
+            )
+
+        # Avoid log(0) issues by clamping stop_lr for computation
+        clamped_stop_lr = max(self.stop_lr, 1e-10)
+        self.min_lr = self.stop_lr
+
         self.decay_rate = np.exp(
-            np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
+            np.log(clamped_stop_lr / self.start_lr) / (decay_total / self.decay_steps)
         ).item()
         if decay_rate is not None:
             self.decay_rate = decay_rate
-        self.min_lr = self.stop_lr
 
-    def value(self, step: int | Array) -> Array:
-        """Get the learning rate at the given step."""
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the exponential-decayed learning rate factor at the given step.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+
+        Returns
+        -------
+        Array
+            The decayed learning rate (absolute value).
+        """
         if not array_api_compat.is_array_api_obj(step):
             step = np.asarray(step)
         xp = array_api_compat.array_namespace(step)
         step_lr = self.start_lr * xp.pow(
             xp.asarray(self.decay_rate, device=array_api_compat.device(step)),
             xp.astype(step // self.decay_steps, xp.float64),
         )
-        # the original implementation `if step_lr < self.min_lr:`
-        # will cause a dynamic graph which is unsupported in JAX JIT
+        # Clip to min_lr for numerical stability in JIT
        step_lr = xp.clip(step_lr, self.min_lr, None)
        return step_lr

@@ -118,29 +283,74 @@ class LearningRateCosine(BaseLR):
     def __init__(
         self,
         start_lr: float,
-        stop_lr: float,
-        stop_steps: int,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        stop_steps: int = 100000,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
         **kwargs: Any,
     ) -> None:
         """
-        Defines a cosine annealing learning rate schedule.
-        The learning rate starts at `start_lr` and gradually decreases to `stop_lr`
-        following a cosine curve over the training steps.
+        Defines a cosine annealing learning rate schedule with optional warmup.
+
+        The learning rate starts at `start_lr` (after warmup) and gradually
+        decreases to `stop_lr` following a cosine curve over the training steps.
 
         Parameters
         ----------
-        start_lr
-            The initial learning rate at the beginning of training.
-        stop_lr
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
             The final learning rate at the end of training.
-        stop_steps
-            The total number of training steps over which the learning rate
-            will be annealed from start_lr to stop_lr.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr.
+            Mutually exclusive with stop_lr.
+        stop_steps : int
+            The total training steps (including warmup).
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0.
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            Default is 0.0.
+
+        Raises
+        ------
+        ValueError
+            If both stop_lr and stop_ratio are provided, or neither is provided.
+            If both warmup_steps and warmup_ratio are provided.
         """
-        super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
-        self.lr_min_factor = stop_lr / start_lr
+        super().__init__(
+            start_lr=start_lr,
+            stop_lr=stop_lr,
+            stop_ratio=stop_ratio,
+            stop_steps=stop_steps,
+            warmup_steps=warmup_steps,
+            warmup_ratio=warmup_ratio,
+            warmup_start_factor=warmup_start_factor,
+            **kwargs,
+        )
+        self.lr_min_factor = self.stop_lr / self.start_lr
 
-    def value(self, step: int | Array) -> Array:
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the cosine-annealed learning rate at the given step.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+
+        Returns
+        -------
+        Array
+            The annealed learning rate (absolute value).
+        """
         if not array_api_compat.is_array_api_obj(step):
             step = np.asarray(step)
         xp = array_api_compat.array_namespace(step)
@@ -153,11 +363,12 @@ def value(self, step: int | Array) -> Array:
                 1
                 + xp.cos(
                     xp.asarray(
-                        xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
+                        xp.pi * (xp.astype(step, xp.float64) / self.decay_stop_steps),
                         device=array_api_compat.device(step),
                     )
                 )
             )
         )
-        step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr)
+        # Clip to min_lr for steps beyond decay_stop_steps
+        step_lr = xp.where(step >= self.decay_stop_steps, min_lr, step_lr)
         return step_lr
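
As a rough numerical check of the piecewise behavior introduced here (linear warmup followed by a clipped exponential staircase), a NumPy-only sketch in the spirit of the cross-backend consistency tests mentioned in the commit message; the import path is assumed from the file header above, and the expected values are derived from the formulas in the diff.

import numpy as np

from deepmd.dpmodel.utils.learning_rate import LearningRateExp

sched = LearningRateExp(
    start_lr=1e-3, stop_lr=1e-5, stop_steps=2000, decay_steps=100,
    warmup_steps=200, warmup_start_factor=0.1,
)

steps = np.arange(2000)
lrs = sched.value(steps)

# Warmup ramps linearly from 0.1 * start_lr toward start_lr ...
assert np.isclose(lrs[0], 1e-4)
assert np.isclose(lrs[100], 1e-4 + 0.9e-3 * 0.5)
# ... and the decay phase starts at start_lr right after warmup.
assert np.isclose(lrs[200], 1e-3)
# The schedule never drops below stop_lr (xp.clip keeps it JIT-friendly).
assert lrs.min() >= 1e-5 - 1e-12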
