|
1 | 1 | """Register Hydra configs for stable_baselines3 feature extractors.""" |
2 | 2 | import dataclasses |
| 3 | +from enum import Enum |
3 | 4 |
|
| 5 | +import stable_baselines3.common.torch_layers as torch_layers |
4 | 6 | from hydra.core.config_store import ConfigStore |
5 | | -from omegaconf import MISSING |
6 | 7 |
|
7 | 8 |
|
8 | | -@dataclasses.dataclass |
9 | | -class Config: |
10 | | - """Base config for stable_baselines3 feature extractors.""" |
| 9 | +class FeatureExtractorClass(Enum): |
| 10 | + """Enum of feature extractor classes.""" |
11 | 11 |
|
12 | | - _target_: str = MISSING |
| 12 | + FlattenExtractor = torch_layers.FlattenExtractor |
| 13 | + NatureCNN = torch_layers.NatureCNN |
13 | 14 |
|
14 | 15 |
|
15 | 16 | @dataclasses.dataclass |
16 | | -class FlattenExtractorConfig(Config): |
17 | | - """Config for FlattenExtractor.""" |
| 17 | +class Config: |
| 18 | + """Base config for stable_baselines3 feature extractors.""" |
18 | 19 |
|
19 | | - _target_: str = ( |
20 | | - "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" |
21 | | - ) |
| 20 | + feature_extractor_class: FeatureExtractorClass |
| 21 | + _target_: str = "imitation_cli.utils.feature_extractor_class.Config.make" |
22 | 22 |
|
23 | 23 | @staticmethod |
24 | | - def make() -> type: |
25 | | - import stable_baselines3 |
26 | | - |
27 | | - return stable_baselines3.common.torch_layers.FlattenExtractor |
| 24 | + def make(feature_extractor_class: FeatureExtractorClass) -> type: |
| 25 | + return feature_extractor_class.value |
28 | 26 |
|
29 | 27 |
|
30 | | -@dataclasses.dataclass |
31 | | -class NatureCNNConfig(Config): |
32 | | - """Config for NatureCNN.""" |
33 | | - |
34 | | - _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make" |
35 | | - |
36 | | - @staticmethod |
37 | | - def make() -> type: |
38 | | - import stable_baselines3 |
39 | | - |
40 | | - return stable_baselines3.common.torch_layers.NatureCNN |
| 28 | +FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor) |
| 29 | +NatureCNN = Config(FeatureExtractorClass.NatureCNN) |
41 | 30 |
|
42 | 31 |
|
43 | 32 | def register_configs(group: str): |
44 | 33 | cs = ConfigStore.instance() |
45 | | - cs.store(group=group, name="flatten", node=FlattenExtractorConfig) |
46 | | - cs.store(group=group, name="nature_cnn", node=NatureCNNConfig) |
| 34 | + for cls in FeatureExtractorClass: |
| 35 | + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) |
0 commit comments