@@ -105,16 +105,21 @@ def get_checkpoint_dir_with_step_num(
105105 checkpoint_root_path : str ,
106106 trainer_type : str = "verl" ,
107107 step_num : Optional [int ] = None ,
108- ) -> str :
108+ ) -> Tuple [ str , int ] :
109109 """Get the checkpoint directory from a root checkpoint directory.
110110
111111 Args:
112112 checkpoint_root_path (str): The root checkpoint directory.
113113 trainer_type (str): The trainer type. Only support "verl" for now.
114- step_num (Optional[int], optional): The step number. Defaults to None.
114+ step_num (Optional[int], optional): The step number. If specified,
115+ load the checkpoint with the specified step number. If None,
116+ load the latest checkpoint. Defaults to None.
117+
118+ Returns:
119+ Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
115120 """
116121 if trainer_type == "verl" :
117- return get_verl_checkpoint_dir (checkpoint_path = checkpoint_root_path , step_num = step_num )
122+ return get_verl_checkpoint_info (checkpoint_path = checkpoint_root_path , step_num = step_num )
118123 else :
119124 raise NotImplementedError (f"Unsupported trainer type { trainer_type } " )
120125
@@ -144,8 +149,20 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
144149 raise ValueError (f"Unsupported placement: { placement } " )
145150
146151
147- def get_verl_checkpoint_dir (checkpoint_path : str , step_num : Optional [int ] = None ) -> str :
148- """Get the checkpoint directory from a Verl root checkpoint directory."""
152+ def get_verl_checkpoint_info (
153+ checkpoint_path : str , step_num : Optional [int ] = None
154+ ) -> Tuple [str , int ]:
155+ """Get the checkpoint directory from a Verl root checkpoint directory.
156+
157+ Args:
158+ checkpoint_path (str): The root checkpoint directory.
159+ step_num (Optional[int], optional): The step number. If specified,
160+ load the checkpoint with the specified step number. If None,
161+ load the latest checkpoint. Defaults to None.
162+
163+ Returns:
164+ Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
165+ """
149166 if step_num is None :
150167 # load latest checkpoint
151168 iteration_file = os .path .join (checkpoint_path , "latest_checkpointed_iteration.txt" )
@@ -154,12 +171,12 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
154171 iteration_file , "r" , encoding = "utf-8"
155172 ) as f : # TODO: this file may be modified simultaneously
156173 iteration = f .read ().strip ()
157- return os .path .join (checkpoint_path , f"global_step_{ iteration } " )
174+ return os .path .join (checkpoint_path , f"global_step_{ iteration } " ), int ( iteration )
158175 else :
159176 raise FileNotFoundError (f"No iteration file found in { checkpoint_path } " )
160177 else :
161178 # load specific iteration checkpoint
162- return os .path .join (checkpoint_path , f"global_step_{ step_num } " )
179+ return os .path .join (checkpoint_path , f"global_step_{ step_num } " ), step_num
163180
164181
165182# copy from verl/scripts/model_merger.py
0 commit comments