Skip to content

Commit 68e0e7b

Browse files
committed
Fix hardcoded action dim in pi0 pytorch model
1 parent 981483d commit 68e0e7b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/openpi/models_pytorch/pi0_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ def __init__(self, config):
9797
precision=config.dtype,
9898
)
9999

100-
self.action_in_proj = nn.Linear(32, action_expert_config.width)
101-
self.action_out_proj = nn.Linear(action_expert_config.width, 32)
100+
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
101+
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
102102

103103
if self.pi05:
104104
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
105105
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
106106
else:
107-
self.state_proj = nn.Linear(32, action_expert_config.width)
107+
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
108108
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
109109
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
110110

0 commit comments

Comments
 (0)