Skip to content

Commit 19c6c19

Browse files
committed
Restore aloha data transforms for exported models
1 parent 45109f9 commit 19c6c19

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ uv run scripts/serve_policy.py --env LIBERO --default_prompt "my task"
104104
This option allows serving a model that was trained using the openpi training code.
105105

106106
```bash
107-
uv run scripts/serve_policy.py --env ALOHA_SIM policy:checkpoint --policy.config=pi0_aloha_sim --policy.dir=checkpoints/pi0_aloha_sim/exp_name/10000
107+
uv run scripts/serve_policy.py --default_prompt "my task" policy:checkpoint --policy.config=pi0_aloha_sim --policy.dir=checkpoints/pi0_aloha_sim/exp_name/10000
108108
```
109109

110110
The training config is used to determine which data transformations should be applied to the runtime data before feeding into the model. The norm stats, which are used to normalize the transformed data, are loaded from the checkpoint directory.
@@ -114,10 +114,10 @@ The training config is used to determine which data transformations should be ap
114114
There are also a number of checkpoints that are available as exported JAX graphs, which we trained ourselves using our internal training code. These can be served using the following command:
115115

116116
```bash
117-
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model [--policy.processor=trossen_biarm_single_base_cam_24dim]
117+
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_base/model [--policy.processor=trossen_biarm_single_base_cam_24dim]
118118
```
119119

120-
For these exported models, norm stats are loaded from processors that are exported along with the model, while data transformations are defined in the corresponding default policy (see `create_default_policy` in [scripts/serve_policy.py](scripts/serve_policy.py)). The processor name is optional, and if not provided, we will do the following:
120+
For these exported models, norm stats are loaded from processors that are exported along with the model, while data transformations are defined by the --env argument (see `create_exported_policy` in [scripts/serve_policy.py](scripts/serve_policy.py)). The processor name is optional, and if not provided, we will do the following:
121121
- Load a processor if there is only one available
122122
- Raise an error if there are multiple processors available and ask to provide a processor name
123123

scripts/serve_policy.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from openpi import transforms
1010
from openpi.models import exported as _exported
11+
from openpi.policies import aloha_policy
1112
from openpi.policies import calvin_policy
1213
from openpi.policies import droid_policy
1314
from openpi.policies import libero_policy
@@ -156,6 +157,23 @@ def make_policy_config(
156157

157158
logging.info("Creating policy...")
158159
match env:
160+
case EnvMode.ALOHA:
161+
delta_action_mask = transforms.make_bool_mask(6, -1, 6, -1)
162+
config = make_policy_config(
163+
input_layers=[
164+
aloha_policy.AlohaInputs(action_dim=model.action_dim, adapt_to_pi=True),
165+
transforms.DeltaActions(mask=delta_action_mask),
166+
],
167+
output_layers=[
168+
transforms.AbsoluteActions(mask=delta_action_mask),
169+
aloha_policy.AlohaOutputs(adapt_to_pi=True),
170+
],
171+
)
172+
case EnvMode.ALOHA_SIM:
173+
config = make_policy_config(
174+
input_layers=[aloha_policy.AlohaInputs(action_dim=model.action_dim)],
175+
output_layers=[aloha_policy.AlohaOutputs()],
176+
)
159177
case EnvMode.DROID:
160178
config = make_policy_config(
161179
input_layers=[droid_policy.DroidInputs(action_dim=model.action_dim)],

0 commit comments

Comments
 (0)