Skip to content

Commit ecdcb11

Browse files
author
muzhancun
committed
[Online] rich wandb logger
1 parent 97f02c4 commit ecdcb11

File tree

8 files changed

+104
-20
lines changed

8 files changed

+104
-20
lines changed

minestudio/online/rollout/env_worker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,15 @@ def reset_state(self) -> Dict[str, torch.Tensor]:
181181
self.conn.send(("reset_state", None))
182182
return self.conn.recv()
183183

184-
def report_rewards(self, rewards: np.ndarray):
184+
def report_rewards(self, rewards: np.ndarray, task: Optional[str] = None):
185185
"""
186186
Sends the rewards for an episode to the main process.
187187
188188
:param rewards: A NumPy array of rewards for the episode.
189+
:param task: An optional string specifying the task configuration.
189190
:returns: The result from the main process.
190191
"""
191-
self.conn.send(("report_rewards", rewards))
192+
self.conn.send(("report_rewards", rewards, task))
192193
return self.conn.recv()
193194

194195
def run(self) -> None:
@@ -251,7 +252,7 @@ def run(self) -> None:
251252
video_writer.close_video()
252253
#_result = self.report_rewards(np.array(reward_list))
253254

254-
_result, episode_info = self.report_rewards(np.array(reward_list))
255+
_result, episode_info = self.report_rewards(np.array(reward_list), obs.get("task", None))
255256
obs["online_info"] = episode_info
256257

257258
if _result is not None:

minestudio/online/rollout/episode_statistics.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from collections import deque
77
import logging
8-
8+
from rich import print
99
logger = logging.getLogger("ray")
1010
@ray.remote
1111
class EpisodeStatistics:
@@ -47,18 +47,30 @@ def log_statistics(self, step: int, record_next_episode: bool):
4747
num_test_tasks = 0
4848
sum_discounted_reward = 0
4949
sum_episode_length = 0
50+
num_valid_episode_length = 0 # Track valid episode length count
51+
52+
# [Change log] Move the logging of step to the beginning by zhancun
53+
wandb_logger.log({
54+
"episode_statistics/step": step,
55+
})
5056

5157
for task in self.sum_rewards_metrics.keys():
5258
mean_sum_reward = self.sum_rewards_metrics[task].compute()
5359
mean_discounted_reward = self.discounted_rewards_metrics[task].compute()
5460
mean_episode_length = self.episode_lengths_metrics[task].compute()
5561

62+
# Log individual task metrics
63+
if not np.isnan(mean_sum_reward):
64+
wandb_logger.log({
65+
f"episode_statistics/{task}/sum_reward": mean_sum_reward,
66+
f"episode_statistics/{task}/discounted_reward": mean_discounted_reward,
67+
f"episode_statistics/{task}/episode_length": mean_episode_length,
68+
})
69+
print(f"Task {task} - Sum Reward: {mean_sum_reward}, Discounted Reward: {mean_discounted_reward}, Episode Length: {mean_episode_length}")
70+
5671
self.sum_rewards_metrics[task].reset()
5772
self.discounted_rewards_metrics[task].reset()
5873
self.episode_lengths_metrics[task].reset()
59-
wandb_logger.log({
60-
"episode_statistics/step": step,
61-
})
6274

6375
if not np.isnan(mean_sum_reward) and "4train" in task:
6476
sum_train_reward += mean_sum_reward
@@ -67,22 +79,26 @@ def log_statistics(self, step: int, record_next_episode: bool):
6779
if not np.isnan(mean_sum_reward) and "4test" in task:
6880
sum_test_reward += mean_sum_reward
6981
num_test_tasks += 1
70-
sum_episode_length += mean_episode_length
82+
83+
# Only add episode length if it's not NaN
84+
if not np.isnan(mean_episode_length):
85+
sum_episode_length += mean_episode_length
86+
num_valid_episode_length += 1
7187

7288
self.episode_info = {
7389
"steps": step,
7490
"episode_count": self.acc_episode_count,
7591
"mean_sum_reward": sum_train_reward / num_train_tasks if num_train_tasks > 0 else 0,
7692
"mean_discounted_reward": sum_discounted_reward / num_train_tasks if num_train_tasks > 0 else 0,
77-
"mean_episode_length": sum_episode_length / (num_train_tasks + num_test_tasks) if num_train_tasks + num_test_tasks > 0 else 0
93+
"mean_episode_length": sum_episode_length / num_valid_episode_length if num_valid_episode_length > 0 else 0
7894
}
7995
wandb_logger.log({
8096
"episode_statistics/steps": step,
8197
"episode_statistics/episode_count": self.acc_episode_count,
8298
"episode_statistics/mean_sum_reward": sum_train_reward / num_train_tasks if num_train_tasks > 0 else 0,
8399
"episode_statistics/mean_test_sum_reward": sum_test_reward / num_test_tasks if num_test_tasks > 0 else 0,
84100
"episode_statistics/mean_discounted_reward": sum_discounted_reward / num_train_tasks if num_train_tasks > 0 else 0,
85-
"episode_statistics/mean_episode_length": sum_episode_length / (num_train_tasks + num_test_tasks) if num_train_tasks + num_test_tasks > 0 else 0
101+
"episode_statistics/mean_episode_length": sum_episode_length / num_valid_episode_length if num_valid_episode_length > 0 else 0
86102
})
87103

88104
self.acc_episode_count = 0

minestudio/online/rollout/rollout_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,9 @@ def poll_environments(self):
232232
conn.send("ok")
233233
elif args[0] == "report_rewards":
234234
rewards = args[1]
235+
task = args[2] if len(args) > 2 else None
235236
if self.episode_statistics is not None:
236-
video_step, episode_info = ray.get(self.episode_statistics.report_episode.remote(rewards))
237+
video_step, episode_info = ray.get(self.episode_statistics.report_episode.remote(rewards, its_specfg = task if task is not None else ""))
237238
if video_step is not None and video_step > self.video_step:
238239
self.video_step = video_step
239240
conn.send((video_step, episode_info))

minestudio/online/trainer/ppotrainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from minestudio.online.utils import auto_stack
2323
import uuid
2424
import copy
25+
import pickle
2526
import torch.distributed as dist
2627

2728
VERBOSE = False
@@ -595,8 +596,8 @@ def ppo_update(self,
595596
#save model
596597
torch.save(self.inner_model.state_dict(), str(checkpoint_dir / "model.ckpt"))
597598
torch.save(self.optimizer.state_dict(), str(checkpoint_dir / "optimizer.ckpt"))
598-
with open(checkpoint_dir / "whole_config.py", "w") as f:
599-
f.write(self.whole_config)
599+
with open(checkpoint_dir / "whole_config.pkl", "wb") as f:
600+
pickle.dump(self.whole_config, f)
600601

601602
if (
602603
self.last_checkpoint_dir

minestudio/simulator/callbacks/callback.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
'''
22
Date: 2025-01-06 17:32:04
3-
LastEditors: caishaofei-mus1 1744260356@qq.com
4-
LastEditTime: 2025-05-09 14:54:09
3+
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4+
LastEditTime: 2025-06-15 17:02:18
55
FilePath: /MineStudio/minestudio/simulator/callbacks/callback.py
66
'''
77
import os
@@ -152,7 +152,8 @@ def after_render(self, sim, image):
152152
"""
153153
return image
154154

155-
155+
def __repr__(self):
156+
return f"{self.__class__.__name__}()"
156157
class Compose(MinecraftCallback):
157158
"""
158159
A callback that composes multiple callbacks into a single callback.
@@ -289,4 +290,4 @@ def after_render(self, sim, image):
289290
for callback in self.activate_callbacks:
290291
image = callback.before_render(sim, image)
291292
return image
292-
293+

minestudio/simulator/callbacks/fast_reset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
'''
22
Date: 2024-11-11 16:15:32
33
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4-
LastEditTime: 2025-05-26 20:56:38
4+
LastEditTime: 2025-06-12 19:47:39
55
FilePath: /MineStudio/minestudio/simulator/callbacks/fast_reset.py
66
'''
77
import random
88
import numpy as np
99
from minestudio.simulator.callbacks.callback import MinecraftCallback
10+
from minestudio.utils.register import Registers
11+
from rich import print
1012

13+
@Registers.simulator_callback.register
1114
class FastResetCallback(MinecraftCallback):
1215
"""Implements a fast reset mechanism for the Minecraft simulator.
1316
@@ -25,6 +28,28 @@ class FastResetCallback(MinecraftCallback):
2528
:type start_weather: str, optional
2629
"""
2730

31+
def create_from_conf(source):
32+
"""Creates a FastReset from a configuration.
33+
34+
Loads data from the source (file path or dict).
35+
36+
:param source: Configuration source.
37+
:type source: Dict
38+
:returns: FastResetCallback instance or None if no valid configuration is found.
39+
:rtype: Optional[FastResetCallback]
40+
"""
41+
essential_keys = ['biomes', 'random_tp_range']
42+
for key in essential_keys:
43+
if key not in source:
44+
print(f"[red]Missing {key} for FastResetCallback, skipping.[/red]")
45+
return None
46+
return FastResetCallback(
47+
biomes=source['biomes'],
48+
random_tp_range=source['random_tp_range'],
49+
start_time=source.get('start_time', 0),
50+
start_weather=source.get('start_weather', 'clear')
51+
)
52+
2853
def __init__(self, biomes, random_tp_range, start_time=0, start_weather='clear'):
2954
"""Initializes the FastResetCallback.
3055

minestudio/simulator/callbacks/judgereset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
'''
2+
Date: 2025-06-12 19:46:03
3+
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4+
LastEditTime: 2025-06-12 19:47:07
5+
FilePath: /MineStudio/minestudio/simulator/callbacks/judgereset.py
6+
'''
17
from minestudio.simulator.callbacks.callback import MinecraftCallback
28
from minestudio.simulator.utils import MinecraftGUI, GUIConstants
39
from minestudio.simulator.utils.gui import PointDrawCall
10+
from minestudio.utils.register import Registers
11+
from rich import print
412

513
import time
614
from typing import Dict, Literal, Optional, Callable, Tuple
715
import cv2
816

17+
@Registers.simulator_callback.register
918
class JudgeResetCallback(MinecraftCallback):
1019
"""Resets the environment if a time limit is reached or episode terminates.
1120
@@ -17,6 +26,20 @@ class JudgeResetCallback(MinecraftCallback):
1726
Defaults to 600.
1827
:type time_limit: int, optional
1928
"""
29+
30+
def create_from_conf(source: Dict) -> Optional['JudgeResetCallback']:
31+
"""Creates a JudgeResetCallback from a configuration.
32+
33+
:param source: Configuration source.
34+
:type source: Dict
35+
:returns: JudgeResetCallback instance or None if no valid configuration is found.
36+
:rtype: Optional[JudgeResetCallback]
37+
"""
38+
if 'time_limit' not in source:
39+
print("[red]Missing 'time_limit' for JudgeResetCallback, skipping.[/red]")
40+
return None
41+
return JudgeResetCallback(time_limit=source['time_limit'])
42+
2043
def __init__(self, time_limit: int = 600):
2144
"""Initializes the JudgeResetCallback.
2245

minestudio/simulator/callbacks/prev_action.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
'''
22
Date: 2024-11-11 19:31:53
33
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4-
LastEditTime: 2025-05-26 21:18:13
4+
LastEditTime: 2025-06-12 19:48:23
55
FilePath: /MineStudio/minestudio/simulator/callbacks/prev_action.py
66
'''
77
import os
@@ -16,14 +16,30 @@
1616
"hotbar.4", "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9",
1717
"camera"]
1818

19-
# @Registers.simulator_callback.register
19+
@Registers.simulator_callback.register
2020
class PrevActionCallback(MinecraftCallback):
2121
"""
2222
A callback that stores the previous action and adds it to the observation.
2323
2424
This callback is useful for tasks where the agent needs to know its previous
2525
action to make a decision.
2626
"""
27+
28+
def create_from_conf(source):
29+
"""Creates a PrevActionCallback from a configuration.
30+
31+
Loads data from the source (file path or dict).
32+
33+
:param source: Configuration source.
34+
:type source: Dict
35+
:returns: PrevActionCallback or None if no valid configuration is found.
36+
:rtype: Optional[PrevActionCallback]
37+
"""
38+
if 'use_prev_action' in source and source['use_prev_action']:
39+
return PrevActionCallback()
40+
else:
41+
print("[red]use_prev_action is not set to True, skipping PrevActionCallback.[/red]")
42+
return None
2743

2844
def __init__(self):
2945
"""

0 commit comments

Comments
 (0)