Skip to content

Commit 8c96f97

Browse files
committed
update: betas
1 parent 0179cbf commit 8c96f97

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

pytorch_optimizer/base/optimizer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def get_adanorm_gradient(
215215
return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
216216

217217
@staticmethod
218-
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)'):
218+
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
219219
if range_type == '[)' and not low <= x < high:
220220
raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
221221
if range_type == '[]' and not low <= x <= high:
@@ -226,40 +226,42 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
226226
raise ValueError(f'[-] {name} must be in the range ({low}, {high})')
227227

228228
@staticmethod
229-
def validate_non_negative(x: Optional[float], name: str):
229+
def validate_non_negative(x: Optional[float], name: str) -> None:
230230
if x is not None and x < 0.0:
231231
raise ValueError(f'[-] {name} must be non-negative')
232232

233233
@staticmethod
234-
def validate_positive(x: Union[float, int], name: str):
234+
def validate_positive(x: Union[float, int], name: str) -> None:
235235
if x <= 0:
236236
raise ValueError(f'[-] {name} must be positive')
237237

238238
@staticmethod
239-
def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper'):
239+
def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
240240
if bound_type == 'upper' and constant > boundary:
241241
raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
242242
if bound_type == 'lower' and constant < boundary:
243243
raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')
244244

245245
@staticmethod
246-
def validate_step(step: int, step_type: str):
246+
def validate_step(step: int, step_type: str) -> None:
247247
if step < 1:
248248
raise NegativeStepError(step, step_type=step_type)
249249

250250
@staticmethod
251-
def validate_options(x: str, name: str, options: List[str]):
251+
def validate_options(x: str, name: str, options: List[str]) -> None:
252252
if x not in options:
253253
opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip()
254254
raise ValueError(f'[-] {name} {x} must be one of ({opts})')
255255

256256
@staticmethod
257-
def validate_learning_rate(learning_rate: Optional[float]):
257+
def validate_learning_rate(learning_rate: Optional[float]) -> None:
258258
if learning_rate is not None and learning_rate < 0.0:
259259
raise NegativeLRError(learning_rate)
260260

261-
def validate_betas(self, betas: BETAS):
262-
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
261+
def validate_betas(self, betas: BETAS) -> None:
262+
if betas[0] is not None:
263+
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
264+
263265
self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')
264266

265267
if len(betas) < 3:
@@ -268,7 +270,7 @@ def validate_betas(self, betas: BETAS):
268270
if betas[2] is not None:
269271
self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')
270272

271-
def validate_nus(self, nus: Union[float, Tuple[float, float]]):
273+
def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
272274
if isinstance(nus, float):
273275
self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
274276
else:

pytorch_optimizer/base/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
CLOSURE = Optional[Callable[[], float]]
88
LOSS = Optional[float]
9-
BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
9+
BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]
1010
DEFAULTS = Dict
1111
PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
1212
STATE = Dict

tests/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,9 @@
348348
(DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
349349
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
350350
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
351-
(AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
352-
(AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 125),
351+
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
352+
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
353+
(AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
353354
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
354355
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
355356
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),

0 commit comments

Comments
 (0)