forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
21 lines (15 loc) · 646 Bytes
/
train.py
File metadata and controls
21 lines (15 loc) · 646 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hydra
import torchrl
from torchrl.trainers.algorithms.configs import * # noqa: F401, F403
@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    """Instantiate the trainer described by ``cfg.trainer`` and train it.

    Hydra composes the configuration named ``"config"`` from the ``"config"``
    config path and passes it in as ``cfg``. The ``trainer`` node of that
    configuration is handed to ``hydra.utils.instantiate`` to build the
    trainer object. A reward-logging hook is registered on it, then
    training is run.

    Args:
        cfg: the composed Hydra configuration; must contain a ``trainer``
            node (presumably an OmegaConf ``DictConfig`` -- TODO confirm
            against the config files).
    """

    def _log_mean_reward(batch):
        # Registered under the trainer's "batch_process" hook. It logs the
        # mean of the ("next", "reward") entry of the batch it receives.
        # It returns None; presumably the trainer then keeps the batch
        # unchanged -- confirm against the TorchRL Trainer docs.
        mean_reward = batch["next", "reward"].mean()
        torchrl.logger.info(f"reward: {mean_reward: 4.4f}")

    algo_trainer = hydra.utils.instantiate(cfg.trainer)
    algo_trainer.register_op(dest="batch_process", op=_log_mean_reward)
    algo_trainer.train()


if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator is what supplies
    # ``cfg`` from the config files / command-line overrides.
    main()