Skip to content

Commit ae4d594

Browse files
authored
Adding warning when discrete action is truncated and increased dicret… (#119)
1 parent cd8c70d commit ae4d594

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

configs/examples/pi05_training_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"freeze_vision_encoder": true,
3838
"train_expert_only": true,
3939
"prompt_max_length": 256,
40-
"discrete_action_max_length": 60,
40+
"discrete_action_max_length": 100,
4141
"optimizer_lr": 2.5e-05,
4242
"optimizer_betas": [
4343
0.9,

src/opentau/policies/pi05/modeling_pi05.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ def pad_discrete_tokens(tokens: list[list[int]], max_length: int) -> tuple[np.nd
215215
discrete_action_masks = []
216216
for token in tokens:
217217
if len(token) > max_length:
218+
logging.warning(
219+
f"Discrete action token length {len(token)} is greater than max_length {max_length}, truncating"
220+
)
218221
discrete_action_tokens.append(np.array(token[:max_length]))
219222
discrete_action_masks.append(np.ones(max_length, dtype=bool))
220223
else:

0 commit comments

Comments
 (0)