Skip to content

Commit 05d8ce7

Browse files
committed
Adds rewards and epsilon explanation to console stats
1 parent 912169d commit 05d8ce7

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

examples/clean_rl_pqn_example.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pathlib
77
import random
88
import time
9+
from collections import deque
910
from dataclasses import dataclass
1011

1112
import gymnasium as gym
@@ -181,6 +182,10 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
181182
next_obs = torch.Tensor(next_obs).to(device)
182183
next_done = torch.zeros(args.num_envs).to(device)
183184

185+
# Episode reward stats are accumulated manually here (modified from the original), as Godot RL does not yet return this information in `info`
186+
episode_returns = deque(maxlen=20)
187+
accum_rewards = np.zeros(args.num_envs)
188+
184189
for iteration in range(1, args.num_iterations + 1):
185190
# Annealing the rate if instructed to do so.
186191
if args.anneal_lr:
@@ -212,12 +217,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
212217
rewards[step] = torch.tensor(reward).to(device).view(-1)
213218
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
214219

215-
if "final_info" in infos:
216-
for info in infos["final_info"]:
217-
if info and "episode" in info:
218-
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
219-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
220-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
220+
accum_rewards += np.array(reward)
221+
222+
for i, d in enumerate(next_done):
223+
if d:
224+
episode_returns.append(accum_rewards[i])
225+
accum_rewards[i] = 0
221226

222227
# Compute Q(lambda) targets
223228
with torch.no_grad():
@@ -261,8 +266,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
261266

262267
writer.add_scalar("losses/td_loss", loss, global_step)
263268
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
264-
print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon: {epsilon}")
269+
print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon (rand action prob): {epsilon}")
270+
if len(episode_returns) > 0:
271+
print(
272+
"Returns:",
273+
np.mean(np.array(episode_returns)),
274+
)
265275
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
276+
writer.add_scalar("charts/episodic_return", np.mean(np.array(episode_returns)), global_step)
266277

267278
envs.close()
268279
writer.close()
@@ -298,4 +309,4 @@ def forward(self, onnx_obs, state_ins):
298309
"output": {0: "batch_size"},
299310
"state_outs": {0: "batch_size"},
300311
},
301-
)
312+
)

0 commit comments

Comments
 (0)