diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py index c521b813a..9a1f02029 100644 --- a/src/forge/actors/__init__.py +++ b/src/forge/actors/__init__.py @@ -4,25 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"] +import importlib -def __getattr__(name): - if name == "Policy": - from .policy import Policy - - return Policy - elif name == "PolicyRouter": - from .policy import PolicyRouter +__all__ = [ + "Policy", + "SamplingOverrides", + "WorkerConfig", + "PolicyConfig", + "ReplayBuffer", + "RLTrainer", +] - return PolicyRouter - elif name == "RLTrainer": - from .trainer import RLTrainer - return RLTrainer - elif name == "ReplayBuffer": - from .replay_buffer import ReplayBuffer - - return ReplayBuffer - else: - raise AttributeError(f"module {__name__} has no attribute {name}") +def __getattr__(name): + if name in __all__: + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}")