Skip to content

Commit 1c3f681

Browse files
committed
feature: support filters
1 parent 8c6b343 commit 1c3f681

File tree

1 file changed

+50
-7
lines changed

1 file changed

+50
-7
lines changed

pytorch_optimizer/__init__.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa
2+
import fnmatch
23
from importlib.util import find_spec
3-
from typing import Dict, List
4+
from typing import Dict, List, Optional, Sequence, Set, Union
45

56
import torch.cuda
67
from torch import nn
@@ -416,13 +417,55 @@ def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
416417
return LR_SCHEDULERS[lr_scheduler]
417418

418419

419-
def get_supported_optimizers() -> List[OPTIMIZER]:
420-
return OPTIMIZER_LIST
420+
def get_supported_optimizers(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
421+
r"""Return list of available optimizer names, sorted alphabetically.
421422
423+
:param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will
424+
return the whole list.
425+
"""
426+
if filters is None:
427+
return sorted(OPTIMIZERS.keys())
428+
429+
include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]
430+
431+
filtered_list: Set[str] = set()
432+
for include_filter in include_filters:
433+
filtered_list.update(fnmatch.filter(OPTIMIZERS.keys(), include_filter))
434+
435+
return sorted(filtered_list)
436+
437+
438+
def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
439+
r"""Return list of available lr scheduler names, sorted alphabetically.
440+
441+
:param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will
442+
return the whole list.
443+
"""
444+
if filters is None:
445+
return sorted(LR_SCHEDULERS.keys())
446+
447+
include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]
448+
449+
filtered_list: Set[str] = set()
450+
for include_filter in include_filters:
451+
filtered_list.update(fnmatch.filter(LR_SCHEDULERS.keys(), include_filter))
452+
453+
return sorted(filtered_list)
454+
455+
456+
def get_supported_loss_functions(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
457+
r"""Return list of available loss function names, sorted alphabetically.
458+
459+
:param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will
460+
return the whole list.
461+
"""
462+
if filters is None:
463+
return sorted(LOSS_FUNCTIONS.keys())
422464

423-
def get_supported_lr_schedulers() -> List[SCHEDULER]:
424-
return list(LR_SCHEDULER_LIST.values())
465+
include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]
425466

467+
filtered_list: Set[str] = set()
468+
for include_filter in include_filters:
469+
filtered_list.update(fnmatch.filter(LOSS_FUNCTIONS.keys(), include_filter))
426470

427-
def get_supported_loss_functions() -> List[nn.Module]:
428-
return LOSS_FUNCTION_LIST
471+
return sorted(filtered_list)

0 commit comments

Comments
 (0)