-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathprecondition_schedules.py
More file actions
199 lines (160 loc) · 7.6 KB
/
precondition_schedules.py
File metadata and controls
199 lines (160 loc) · 7.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from abc import ABC, abstractmethod
from typing import override
# Names exported by a star-import of this module: the three concrete schedules.
# The abstract base class PreconditionSchedule is not included in this list.
__all__ = [
"LinearSchedule",
"CosineSchedule",
"StepSchedule",
]
class PreconditionSchedule(ABC):
    """Abstract, callable schedule for the preconditioner update frequency.

    An instance is called with the current training step and returns the
    update frequency to use at that step:

    - a frequency of 1 means the preconditioner is updated every step,
    - a frequency of 10 means it is updated every 10 steps.

    Subclasses supply the actual curve by implementing
    ``_compute_frequency``; this base class validates the bounds, handles
    the warm-up period before ``start_step`` and clamps the result.

    Args:
        min_freq: Smallest frequency value (most frequent updates).
        max_freq: Largest frequency value (least frequent updates).
        start_step: First step at which the schedule is consulted; every
            earlier step yields ``min_freq``.

    Raises:
        ValueError: If ``min_freq < 1``, ``max_freq < min_freq`` or
            ``start_step < 0``.
    """

    def __init__(self, min_freq: int = 1, max_freq: int = 100, start_step: int = 0):
        """Validate and store the frequency bounds and the start step."""
        # Checks run in this order so the first violated constraint is reported.
        if min_freq < 1:
            raise ValueError("min_freq must be at least 1")
        if max_freq < min_freq:
            raise ValueError("max_freq must be >= min_freq")
        if start_step < 0:
            raise ValueError("start_step must be non-negative")

        self.min_freq = min_freq
        self.max_freq = max_freq
        self.start_step = start_step

    def __call__(self, step: int) -> int:
        """Return the update frequency for ``step``.

        Args:
            step: Current training step (must be >= 0).

        Returns:
            ``min_freq`` while ``step < start_step``; otherwise the value of
            ``_compute_frequency(step)`` clamped to ``[min_freq, max_freq]``.

        Raises:
            ValueError: If ``step`` is negative.
        """
        if step < 0:
            raise ValueError("step must be non-negative")

        # Warm-up phase: always use the most frequent update rate.
        if step < self.start_step:
            return self.min_freq

        # The absolute step (not an offset from start_step) is handed to the subclass.
        raw_freq = self._compute_frequency(step)
        upper_bounded = min(self.max_freq, raw_freq)
        return max(self.min_freq, upper_bounded)

    @abstractmethod
    def _compute_frequency(self, step: int) -> int:
        """Compute the unclamped frequency for ``step``; subclasses must override.

        Args:
            step: Current training step, exactly as passed to ``__call__``.

        Returns:
            The frequency before it is clamped to ``[min_freq, max_freq]``.
        """
class LinearSchedule(PreconditionSchedule):
    """Linear transition from frequent to infrequent preconditioning.

    The frequency is interpolated linearly from ``min_freq`` (at step 0) to
    ``max_freq`` (at step ``transition_steps``) and rounded down to an
    integer. After the transition period the frequency stays at
    ``max_freq``.

    Note:
        Progress is measured on the absolute training step, not on the
        offset from ``start_step``. With a non-zero ``start_step`` the
        frequency is ``min_freq`` before ``start_step`` and then continues
        from wherever the ramp is at that step.
    """

    def __init__(self, min_freq: int = 1, max_freq: int = 100, transition_steps: int = 10000, start_step: int = 0):
        """Initialize linear schedule.

        Args:
            min_freq: Starting frequency (most frequent updates).
            max_freq: Ending frequency (least frequent updates).
            transition_steps: Number of steps over which to transition.
            start_step: Step at which to start applying the schedule
                (before this, ``min_freq`` is used).

        Raises:
            ValueError: If the bounds or ``start_step`` are invalid, or if
                ``transition_steps`` is not positive.
        """
        super().__init__(min_freq, max_freq, start_step)
        if transition_steps <= 0:
            raise ValueError("transition_steps must be positive")
        self.transition_steps = transition_steps

    @override
    def _compute_frequency(self, step: int) -> int:
        """Return the floor of the linear ramp at ``step``.

        Args:
            step: Current training step.

        Returns:
            ``min_freq + floor((max_freq - min_freq) * step / transition_steps)``
            while ``step <= transition_steps``; ``max_freq`` afterwards.
        """
        if step > self.transition_steps:
            return self.max_freq

        span = self.max_freq - self.min_freq
        # Exact integer floor division. Going through a float ratio and then
        # truncating can land one below the true value (for example
        # 100 * (29 / 100) evaluates to 28.999999999999996).
        return int(self.min_freq + span * step // self.transition_steps)
class CosineSchedule(PreconditionSchedule):
    """Cosine-shaped schedule that cycles between ``min_freq`` and ``max_freq``.

    Within each period of ``transition_steps`` steps the frequency follows a
    half cosine wave, rising smoothly from ``min_freq`` toward ``max_freq``.
    The position inside the period is ``step % transition_steps``, so the
    curve restarts at ``min_freq`` at the beginning of every period. Choose
    ``transition_steps`` at least as large as the training length to get a
    single cosine increase.
    """

    def __init__(self, min_freq: int = 1, max_freq: int = 50, transition_steps: int = 20000, start_step: int = 0):
        """Initialize cosine schedule.

        Args:
            min_freq: Minimum frequency in the oscillation.
            max_freq: Maximum frequency in the oscillation.
            transition_steps: Length of one cosine period, in steps.
            start_step: Step at which to start applying the schedule
                (before this, ``min_freq`` is used).

        Raises:
            ValueError: If the bounds or ``start_step`` are invalid, or if
                ``transition_steps`` is not positive.
        """
        super().__init__(min_freq, max_freq, start_step)
        if transition_steps <= 0:
            raise ValueError("transition_steps must be positive")
        self.transition_steps = transition_steps

    @override
    def _compute_frequency(self, step: int) -> int:
        """Return the cosine-interpolated frequency for ``step``.

        Args:
            step: Current training step.

        Returns:
            The interpolated frequency rounded to the nearest integer.
        """
        # Phase runs over [0, pi) inside each period of transition_steps.
        phase = math.pi * (step % self.transition_steps) / self.transition_steps
        # progress is 1 at the start of a period and approaches 0 toward its end.
        progress = (1 + math.cos(phase)) / 2
        current_freq = self.max_freq - (self.max_freq - self.min_freq) * progress
        # Round to nearest instead of truncating. The phase never reaches pi,
        # so current_freq stays strictly below max_freq, and truncation would
        # make max_freq unreachable and let float noise drop a value by one.
        return round(current_freq)
class StepSchedule(PreconditionSchedule):
    """Step-wise schedule with predefined frequency changes at specific steps.

    The schedule maps step thresholds to frequencies. The frequency in
    effect at a given step is the one attached to the largest threshold that
    is less than or equal to that step, so it stays constant between
    thresholds. Steps below the smallest threshold use the frequency of the
    smallest threshold.

    Example:
        # Different frequencies for different training phases
        schedule = StepSchedule({
            0: 1,       # Update every step for first 1000 steps
            1000: 5,    # Update every 5 steps from 1000-4999
            5000: 10,   # Update every 10 steps from 5000-9999
            10000: 25   # Update every 25 steps from 10000 onwards
        })
    """

    def __init__(self, schedule_dict: dict[int, int], start_step: int = 0):
        """Initialize with a dictionary mapping steps to frequencies.

        Args:
            schedule_dict: Dictionary mapping step thresholds to frequencies.
                - Keys must be non-negative integers (steps).
                - Values must be positive integers (frequencies).
                The mapping is copied, so later changes to the caller's
                dictionary do not affect the schedule.
            start_step: Step at which to start applying the schedule
                (before this, the smallest configured frequency is used).

        Raises:
            ValueError: If ``schedule_dict`` is empty, contains an invalid
                threshold or frequency, or ``start_step`` is negative.
        """
        if not schedule_dict:
            raise ValueError("schedule_dict cannot be empty")

        # Validate inputs
        for threshold, freq in schedule_dict.items():
            if not isinstance(threshold, int) or threshold < 0:
                raise ValueError("All step thresholds must be non-negative integers")
            if not isinstance(freq, int) or freq < 1:
                raise ValueError("All frequencies must be positive integers")

        # Keep a private snapshot. Holding the caller's dict by reference
        # would let later mutations desynchronise it from sorted_steps,
        # causing a KeyError or silently ignored thresholds.
        self.schedule_dict = dict(schedule_dict)
        self.sorted_steps = sorted(self.schedule_dict)

        # The bounds of the schedule are the extreme configured frequencies.
        frequencies = self.schedule_dict.values()
        super().__init__(min(frequencies), max(frequencies), start_step)

    @override
    def _compute_frequency(self, step: int) -> int:
        """Return the frequency of the last threshold not exceeding ``step``.

        Args:
            step: Current training step.

        Returns:
            The frequency for the largest threshold ``<= step``, or the
            frequency of the smallest threshold when ``step`` is below all
            thresholds.
        """
        # Default to the first (smallest) threshold's value.
        active_threshold = self.sorted_steps[0]
        for threshold in self.sorted_steps:
            if step < threshold:
                break  # thresholds are sorted; no later one can match
            active_threshold = threshold
        return self.schedule_dict[active_threshold]