 11 |  11 | """
 12 |  12 | from __future__ import annotations
 13 |  13 |
 14 |     | -import time
 15 |  14 | import warnings
 16 |  15 |
 17 |  16 | import hydra
 21 |  20 | import tqdm
 22 |  21 | from tensordict.nn import CudaGraphModule
 23 |  22 |
 24 |     | -from torchrl._utils import logger as torchrl_logger, timeit
    |  23 | +from torchrl._utils import timeit
 25 |  24 | from torchrl.envs.utils import ExplorationType, set_exploration_type
 26 |  25 | from torchrl.objectives import group_optimizers
 27 |  26 | from torchrl.record.loggers import generate_exp_name, get_logger
@@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration):
156 | 155 |     eval_steps = cfg.logger.eval_steps
157 | 156 |
158 | 157 |     # Training loop
159 |     | -    start_time = time.time()
160 | 158 |     policy_eval_start = torch.tensor(policy_eval_start, device=device)
161 | 159 |     for i in range(gradient_steps):
    | 160 | +        timeit.printevery(1000, gradient_steps, erase=True)
162 | 161 |         pbar.update(1)
163 | 162 |         # sample data
164 | 163 |         with timeit("sample"):
@@ -192,15 +191,10 @@ def update(data, policy_eval_start, iteration):
192 | 191 |                 to_log["evaluation_reward"] = eval_reward
193 | 192 |
194 | 193 |         with timeit("log"):
195 |     | -            if i % 200 == 0:
196 |     | -                to_log.update(timeit.todict(prefix="time"))
    | 194 | +            to_log.update(timeit.todict(prefix="time"))
197 | 195 |             log_metrics(logger, to_log, i)
198 |     | -            if i % 200 == 0:
199 |     | -                timeit.print()
200 |     | -                timeit.erase()
201 | 196 |
202 | 197 |     pbar.close()
203 |     | -    torchrl_logger.info(f"Training time: {time.time() - start_time}")
204 | 198 |     if not eval_env.is_closed:
205 | 199 |         eval_env.close()
206 | 200 |
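For reference, a minimal sketch of the `timeit` pattern this diff moves to, using only the calls that appear above (`timeit(...)` as a context manager, `timeit.todict`, and `timeit.printevery`); the loop body and the step count are placeholders, not the script's actual code:

```python
from torchrl._utils import timeit

gradient_steps = 5000  # placeholder step count

for i in range(gradient_steps):
    # Replaces the manual `if i % 200 == 0: timeit.print(); timeit.erase()`
    # bookkeeping: print accumulated timings every 1000 steps and reset them.
    timeit.printevery(1000, gradient_steps, erase=True)

    with timeit("sample"):
        pass  # e.g. sample a batch from the replay buffer

    with timeit("log"):
        to_log = {}
        # Fold timing stats into the metrics dict on every iteration,
        # instead of only every 200 steps as in the removed code.
        to_log.update(timeit.todict(prefix="time"))
```

With `printevery` handling the periodic print-and-reset, the separate `start_time`/`time.time()` accounting and the `i % 200` branches become redundant, which is what the deletions above remove.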