Skip to content

Commit 9c0e88b

Browse files
authored
fix: add sft ckpt loader (RLinf#796)
Signed-off-by: Florielle <1205402283@qq.com>
1 parent 605bccf commit 9c0e88b

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

rlinf/models/embodiment/openpi/__init__.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import os
1717

18+
import torch
1819
from omegaconf import DictConfig
1920

2021

@@ -43,9 +44,15 @@ def get_model(cfg: DictConfig, torch_dtype=None):
4344
actor_model_config.__dict__[key] = val
4445
# load model
4546
checkpoint_dir = download.maybe_download(str(cfg.model_path))
46-
weight_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors")))
47-
if not weight_paths:
48-
weight_paths = [os.path.join(checkpoint_dir, "model.safetensors")]
47+
48+
# Check if this is a checkpoint directory (saved by FSDP)
49+
# Two layouts are recognized: model_state_dict/full_weights.pt (checkpoint directory passed directly) or actor/model_state_dict/full_weights.pt (checkpoint directory written by the runner)
50+
full_weights_path = os.path.join(
51+
checkpoint_dir, "model_state_dict", "full_weights.pt"
52+
)
53+
actor_full_weights_path = os.path.join(
54+
checkpoint_dir, "actor", "model_state_dict", "full_weights.pt"
55+
)
4956

5057
model: OpenPi0ForRLActionPrediction = OpenPi0ForRLActionPrediction(
5158
actor_model_config
@@ -54,8 +61,23 @@ def get_model(cfg: DictConfig, torch_dtype=None):
5461
if actor_model_config.train_expert_only:
5562
model.freeze_vlm()
5663

57-
for weight_path in weight_paths:
58-
safetensors.torch.load_model(model, weight_path, strict=False)
64+
# Load weights from checkpoint if it's a checkpoint directory, otherwise load from safetensors
65+
if os.path.exists(full_weights_path):
66+
# Direct checkpoint directory
67+
model_state_dict = torch.load(full_weights_path, map_location="cpu")
68+
model.load_state_dict(model_state_dict, strict=False)
69+
elif os.path.exists(actor_full_weights_path):
70+
# Checkpoint directory from runner
71+
model_state_dict = torch.load(actor_full_weights_path, map_location="cpu")
72+
model.load_state_dict(model_state_dict, strict=False)
73+
else:
74+
# Original model directory with safetensors files
75+
weight_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors")))
76+
if not weight_paths:
77+
weight_paths = [os.path.join(checkpoint_dir, "model.safetensors")]
78+
for weight_path in weight_paths:
79+
safetensors.torch.load_model(model, weight_path, strict=False)
80+
5981
model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
6082
# fsdp replace
6183
# model.paligemma_with_expert.replace_gemma_decoder_layers()

0 commit comments

Comments
 (0)