Commit 14fbe42

Fix: advantage zero bug
1 parent 4d08bbe commit 14fbe42

File tree: 2 files changed (+11 -12 lines)

src/opentau/datasets/lerobot_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -1506,7 +1506,7 @@ def __getitem__(self, idx) -> dict:
         item = self._to_standard_data_format(item)
 
         if self.meta.advantages is not None:
-            advantage = self.meta.advantages.get((episode_index, timestamp), 0)
+            advantage = self.meta.advantages.get((episode_index, float(timestamp)), 0)
             item["advantage"] = torch.tensor(advantage, dtype=torch.bfloat16)
         else:
             item["advantage"] = torch.tensor(0.0, dtype=torch.bfloat16)

src/opentau/scripts/get_advantage_and_percentiles.py

Lines changed: 10 additions & 11 deletions

@@ -44,7 +44,7 @@
     auto_torch_device,
     init_logging,
 )
-
+from tqdm import tqdm
 
 def ensure_primitive(maybe_tensor):
     if isinstance(maybe_tensor, np.ndarray):
@@ -152,7 +152,7 @@ def main(cfg: TrainPipelineConfig):
     ds_advantage = {}  # per-dataset advantages
     with torch.inference_mode():
         # First pass to get the values
-        for batch in dataloader:
+        for batch in tqdm(dataloader, desc="Computing values"):
             for key, value in batch.items():
                 if isinstance(value, torch.Tensor):
                     batch[key] = value.to(device)
@@ -172,15 +172,15 @@
                 success=success,
                 n_steps_look_ahead=cfg.policy.reward_config.N_steps_look_ahead,
                 episode_end_idx=episode_end_idx,
-                max_episode_length=cfg.policy.reward_config.reward_normalizer,
+                reward_normalizer=cfg.policy.reward_config.reward_normalizer,
                 current_idx=current_idx,
                 c_neg=cfg.policy.reward_config.C_neg,
             )
 
-            values[(episode_index, current_idx)] = {"v0": v0, "reward": reward}
+            values[(episode_index.item(), current_idx.item())] = {"v0": v0, "reward": reward}
 
         # Second pass to compute the advantages
-        for batch in dataloader:
+        for batch in tqdm(dataloader, desc="Computing advantages"):
             for episode_index, current_idx, timestamp in zip(
                 batch["episode_index"],
                 batch["current_idx"],
@@ -192,13 +192,12 @@
                 )
                 # check if the value for the next n_steps_look_ahead steps is available, else set it to 0
                 look_ahead_idx = current_idx + cfg.policy.reward_config.N_steps_look_ahead
-                vn = values.get((episode_index, look_ahead_idx), _default0)["v0"]
-                reward = values.get((episode_index, current_idx), _default0)["reward"]
-                v0 = values.get((episode_index, current_idx), _default0)["v0"]
+                vn = values.get((episode_index.item(), look_ahead_idx.item()), _default0)["v0"]
+                reward = values.get((episode_index.item(), current_idx.item()), _default0)["reward"]
+                v0 = values.get((episode_index.item(), current_idx.item()), _default0)["v0"]
                 advantage = ensure_primitive(reward + vn - v0)
-                advantages.append(advantage)
-                ds_advantage[(episode_index, timestamp)] = advantage
-
+                advantages.append(advantage.item())
+                ds_advantage[(episode_index.item(), timestamp.item())] = advantage.item()
             # Convert tuple keys to strings for JSON serialization
             advantage_data_json = {f"{ep_idx},{ts}": val for (ep_idx, ts), val in ds_advantage.items()}
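
The script-side .item() calls fix the same class of bug at the source. Both passes key the values and ds_advantage dicts on scalars taken straight from the batch; when those are tensors, every lookup in the second pass missed and returned _default0, so advantage = reward + vn - v0 (an N-step estimate: reward plus the value N_steps_look_ahead steps ahead minus the current value) degenerated to zero for every sample, which is the "advantage zero bug" of the commit title. A minimal sketch of the pitfall, with made-up values:

import torch

# Buggy: tensor keys hash by object identity, so a fresh tensor with the
# same value never matches the stored key.
values = {(torch.tensor(0), torch.tensor(7)): {"v0": 0.9, "reward": 1.0}}
fresh_key = (torch.tensor(0), torch.tensor(7))
print(values.get(fresh_key, {"v0": 0.0})["v0"])  # 0.0 -> silent miss

# Fixed: .item() yields plain Python numbers, which hash and compare by value.
values = {(torch.tensor(0).item(), torch.tensor(7).item()): {"v0": 0.9, "reward": 1.0}}
print(values.get((0, 7), {"v0": 0.0})["v0"])     # 0.9 -> hit

The tqdm wrappers only add progress reporting, and the reward_normalizer= change appears to correct a value that was previously passed under the wrong keyword (max_episode_length).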
