1515
1616import os
1717
18+ import torch
1819from omegaconf import DictConfig
1920
2021
@@ -43,9 +44,15 @@ def get_model(cfg: DictConfig, torch_dtype=None):
4344 actor_model_config .__dict__ [key ] = val
4445 # load model
4546 checkpoint_dir = download .maybe_download (str (cfg .model_path ))
46- weight_paths = sorted (glob .glob (os .path .join (checkpoint_dir , "*.safetensors" )))
47- if not weight_paths :
48- weight_paths = [os .path .join (checkpoint_dir , "model.safetensors" )]
47+
48+ # Check if this is a checkpoint directory (saved by FSDP)
49+ # Check for model_state_dict/full_weights.pt (direct checkpoint) or actor/model_state_dict/full_weights.pt (from runner)
50+ full_weights_path = os .path .join (
51+ checkpoint_dir , "model_state_dict" , "full_weights.pt"
52+ )
53+ actor_full_weights_path = os .path .join (
54+ checkpoint_dir , "actor" , "model_state_dict" , "full_weights.pt"
55+ )
4956
5057 model : OpenPi0ForRLActionPrediction = OpenPi0ForRLActionPrediction (
5158 actor_model_config
@@ -54,8 +61,23 @@ def get_model(cfg: DictConfig, torch_dtype=None):
5461 if actor_model_config .train_expert_only :
5562 model .freeze_vlm ()
5663
57- for weight_path in weight_paths :
58- safetensors .torch .load_model (model , weight_path , strict = False )
64+ # Load weights from checkpoint if it's a checkpoint directory, otherwise load from safetensors
65+ if os .path .exists (full_weights_path ):
66+ # Direct checkpoint directory
67+ model_state_dict = torch .load (full_weights_path , map_location = "cpu" )
68+ model .load_state_dict (model_state_dict , strict = False )
69+ elif os .path .exists (actor_full_weights_path ):
70+ # Checkpoint directory from runner
71+ model_state_dict = torch .load (actor_full_weights_path , map_location = "cpu" )
72+ model .load_state_dict (model_state_dict , strict = False )
73+ else :
74+ # Original model directory with safetensors files
75+ weight_paths = sorted (glob .glob (os .path .join (checkpoint_dir , "*.safetensors" )))
76+ if not weight_paths :
77+ weight_paths = [os .path .join (checkpoint_dir , "model.safetensors" )]
78+ for weight_path in weight_paths :
79+ safetensors .torch .load_model (model , weight_path , strict = False )
80+
5981 model .paligemma_with_expert .to_bfloat16_for_selected_params ("bfloat16" )
6082 # fsdp replace
6183 # model.paligemma_with_expert.replace_gemma_decoder_layers()
0 commit comments