Skip to content

Commit 2aeb561

Browse files
authored
Support multi-episode runs and max episode steps in runtime (#167)
2 parents 19c6c19 + f148c0b commit 2aeb561

File tree

5 files changed

+35
-17
lines changed

5 files changed

+35
-17
lines changed

examples/aloha_real/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def reset(self) -> None:
2121
self._ts = self._env.reset()
2222

2323
@override
24-
def done(self) -> bool:
24+
def is_episode_complete(self) -> bool:
2525
return False
2626

2727
@override

examples/aloha_real/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class Args:
1717

1818
action_horizon: int = 25
1919

20+
num_episodes: int = 1
21+
max_episode_steps: int = 1000
22+
2023

2124
def main(args: Args) -> None:
2225
runtime = _runtime.Runtime(
@@ -32,6 +35,8 @@ def main(args: Args) -> None:
3235
),
3336
subscribers=[],
3437
max_hz=50,
38+
num_episodes=args.num_episodes,
39+
max_episode_steps=args.max_episode_steps,
3540
)
3641

3742
runtime.run()

examples/aloha_sim/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def reset(self) -> None:
2727
self._episode_reward = 0.0
2828

2929
@override
30-
def done(self) -> bool:
30+
def is_episode_complete(self) -> bool:
3131
return self._done
3232

3333
@override

packages/openpi-client/src/openpi_client/runtime/environment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def reset(self) -> None:
1616
"""
1717

1818
@abc.abstractmethod
19-
def done(self) -> bool:
20-
"""Allow the environment to signal that the task is done.
19+
def is_episode_complete(self) -> bool:
20+
"""Allow the environment to signal that the episode is complete.
2121
22-
This will be called after each step. It should return `True` if the task is
23-
done (either successfully or unsuccessfully), and `False` otherwise.
22+
This will be called after each step. It should return `True` if the episode is
23+
complete (either successfully or unsuccessfully), and `False` otherwise.
2424
"""
2525

2626
@abc.abstractmethod

packages/openpi-client/src/openpi_client/runtime/runtime.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,52 @@ def __init__(
1616
agent: _agent.Agent,
1717
subscribers: list[_subscriber.Subscriber],
1818
max_hz: float = 0,
19+
num_episodes: int = 1,
20+
max_episode_steps: int = 0,
1921
) -> None:
2022
self._environment = environment
2123
self._agent = agent
2224
self._subscribers = subscribers
2325
self._max_hz = max_hz
26+
self._num_episodes = num_episodes
27+
self._max_episode_steps = max_episode_steps
2428

25-
self._running = False
29+
self._in_episode = False
30+
self._episode_steps = 0
2631

2732
def run(self) -> None:
2833
"""Runs the configured number of episodes, then resets the environment one final time."""
29-
self._loop()
34+
for _ in range(self._num_episodes):
35+
self._run_episode()
36+
37+
# Final reset: this is important for real environments, where it moves the robot back to its home position.
38+
self._environment.reset()
3039

3140
def run_in_new_thread(self) -> threading.Thread:
3241
"""Runs the runtime loop in a new thread."""
3342
thread = threading.Thread(target=self.run)
3443
thread.start()
3544
return thread
3645

37-
def stop(self) -> None:
38-
"""Stops the runtime loop."""
39-
self._running = False
46+
def mark_episode_complete(self) -> None:
47+
"""Marks the end of an episode."""
48+
self._in_episode = False
4049

41-
def _loop(self) -> None:
42-
"""The runtime loop."""
50+
def _run_episode(self) -> None:
51+
"""Runs a single episode."""
4352
logging.info("Starting episode...")
4453
self._environment.reset()
4554
for subscriber in self._subscribers:
4655
subscriber.on_episode_start()
4756

48-
self._running = True
57+
self._in_episode = True
58+
self._episode_steps = 0
4959
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
5060
last_step_time = time.time()
5161

52-
while self._running:
62+
while self._in_episode:
5363
self._step()
64+
self._episode_steps += 1
5465

5566
# Sleep to maintain the desired frame rate
5667
now = time.time()
@@ -74,5 +85,7 @@ def _step(self) -> None:
7485
for subscriber in self._subscribers:
7586
subscriber.on_step(observation, action)
7687

77-
if self._environment.done():
78-
self.stop()
88+
if self._environment.is_episode_complete() or (
89+
self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
90+
):
91+
self.mark_episode_complete()

0 commit comments

Comments
 (0)