|
1 | 1 | # ruff: noqa |
| 2 | +import fnmatch |
2 | 3 | from importlib.util import find_spec |
3 | | -from typing import Dict, List |
| 4 | +from typing import Dict, List, Optional, Sequence, Set, Union |
4 | 5 |
|
5 | 6 | import torch.cuda |
6 | 7 | from torch import nn |
@@ -416,13 +417,55 @@ def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER: |
416 | 417 | return LR_SCHEDULERS[lr_scheduler] |
417 | 418 |
|
418 | 419 |
|
419 | | -def get_supported_optimizers() -> List[OPTIMIZER]: |
420 | | - return OPTIMIZER_LIST |
| 420 | +def get_supported_optimizers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: |
| 421 | + r"""Return list of available optimizer names, sorted alphabetically. |
421 | 422 |
|
| 423 | + :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will |
| 424 | + return the whole list. |
| 425 | + """ |
| 426 | + if filters is None: |
| 427 | + return sorted(OPTIMIZERS.keys()) |
| 428 | + |
| 429 | + include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] |
| 430 | + |
| 431 | + filtered_list: Set[str] = set() |
| 432 | + for include_filter in include_filters: |
| 433 | + filtered_list.update(fnmatch.filter(OPTIMIZERS.keys(), include_filter)) |
| 434 | + |
| 435 | + return sorted(filtered_list) |
| 436 | + |
| 437 | + |
| 438 | +def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]: |
| 439 | + r"""Return list of available lr scheduler names, sorted alphabetically. |
| 440 | +
|
| 441 | + :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will |
| 442 | + return the whole list. |
| 443 | + """ |
| 444 | + if filters is None: |
| 445 | + return sorted(LR_SCHEDULERS.keys()) |
| 446 | + |
| 447 | + include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] |
| 448 | + |
| 449 | + filtered_list: Set[str] = set() |
| 450 | + for include_filter in include_filters: |
| 451 | + filtered_list.update(fnmatch.filter(LR_SCHEDULERS.keys(), include_filter)) |
| 452 | + |
| 453 | + return sorted(filtered_list) |
| 454 | + |
| 455 | + |
| 456 | +def get_supported_loss_functions(filters: Optional[Union[str, List[str]]] = None) -> List[str]: |
| 457 | + r"""Return list of available loss function names, sorted alphabetically. |
| 458 | +
|
| 459 | + :param filters: Optional[Union[str, List[str]]]. wildcard filter string that works with fmatch. if None, it will |
| 460 | + return the whole list. |
| 461 | + """ |
| 462 | + if filters is None: |
| 463 | + return sorted(LOSS_FUNCTIONS.keys()) |
422 | 464 |
|
423 | | -def get_supported_lr_schedulers() -> List[SCHEDULER]: |
424 | | - return list(LR_SCHEDULER_LIST.values()) |
| 465 | + include_filters: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters] |
425 | 466 |
|
| 467 | + filtered_list: Set[str] = set() |
| 468 | + for include_filter in include_filters: |
| 469 | + filtered_list.update(fnmatch.filter(LOSS_FUNCTIONS.keys(), include_filter)) |
426 | 470 |
|
427 | | -def get_supported_loss_functions() -> List[nn.Module]: |
428 | | - return LOSS_FUNCTION_LIST |
| 471 | + return sorted(filtered_list) |
0 commit comments