33Used to test internal pi checkpoints and provides utilities to convert them to openpi checkpoints.
44"""
55
6+ from collections .abc import Mapping
67import pathlib
78from typing import Any
89
2021from openpi .shared import normalize as _normalize
2122import openpi .shared .array_typing as at
2223import openpi .shared .download as download
24+ import openpi .transforms as _transforms
2325
2426
2527def convert_to_openpi (
26- ckpt_dir : pathlib .Path | str , processor : str , out_dir : pathlib .Path | str , param_path : str = "decoder"
28+ ckpt_dir : pathlib .Path | str ,
29+ processor : str ,
30+ out_dir : pathlib .Path | str ,
31+ * ,
32+ param_path : str = "decoder" ,
33+ transform : Mapping [str , None ] | None = None ,
2734) -> None :
28- """Convert a monopi checkpoint to an openpi checkpoint."""
35+ """Convert an internal checkpoint to an openpi checkpoint.
36+
37+ Args:
38+ ckpt_dir: The directory containing the internal exported model.
39+ processor: The processor name to use to extract the norm stats.
40+ out_dir: The directory to save the openpi checkpoint.
41+ param_path: The path to the parameters within the overall param structure. Can include "/" to support nesting.
42+ transform: Optional transform patterns to use when converting the checkpoint params. Each key maps from the
43+ original param name to the openpi param name. See `determine_transform_patterns` for more details.
44+ """
2945 out_dir = pathlib .Path (out_dir )
3046 if out_dir .exists ():
3147 raise FileExistsError (f"Output directory already exists: { out_dir } " )
@@ -43,7 +59,9 @@ def convert_to_openpi(
4359 raise ValueError (f"{ part } not found in the checkpoint. Available keys: { list (params )} " )
4460 params = params [part ]
4561
46- # Load the monopi model.
62+ if transform is not None :
63+ params = _transforms .transform_dict (transform , params )
64+
4765 # Save params.
4866 ckpt = ocp .StandardCheckpointer ()
4967 ckpt .save (out_dir / "params" , {"params" : params })
@@ -55,7 +73,7 @@ def convert_to_openpi(
5573
5674@struct .dataclass
5775class PiModel (_model .BaseModel ):
58- """A model loaded from a monopi checkpoint model directory."""
76+ """A model loaded from an internal exported model directory."""
5977
6078 params : at .Params
6179
@@ -66,7 +84,7 @@ class PiModel(_model.BaseModel):
6684
6785 @classmethod
6886 def from_checkpoint (cls , ckpt_dir : pathlib .Path | str ) -> "PiModel" :
69- """Load a model from a monopi model checkpoint directory. Must point at the "model" sub-directory."""
87+ """Load a model from the internal checkpoint directory. Must point at the "model" sub-directory."""
7088 ckpt_dir = download .maybe_download (str (ckpt_dir ))
7189 with (ckpt_dir / "graph" ).open ("rb" ) as f :
7290 exported = jax .export .deserialize (f .read ())
@@ -173,6 +191,59 @@ def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model
173191 )
174192
175193
def determine_transform_patterns(
    pi_model: PiModel, module: common.BaseModule, *, param_path: str = "decoder"
) -> dict[str, str]:
    """Determine the transform patterns to use when converting an internal checkpoint to an openpi checkpoint.

    The returned pattern can be used by `transforms.transform_dict` to convert the checkpoint params to the openpi format.

    The module is attached to the loaded model and a fresh set of params is initialized from the model's fake
    observation and action. The flattened names of the initialized params are then compared against the flattened
    names of the checkpoint params. Names that appear on only one side are sorted by (shape, name) and paired
    positionally, so params that share a shape are matched purely by name order. Review the result before using
    it when several unmatched params have the same shape.

    Args:
        pi_model: The model loaded from the internal checkpoint.
        module: The module whose freshly initialized params define the expected param names.
        param_path: Forwarded to `pi_model.set_module`.

    Returns:
        A dict mapping each unmatched checkpoint param name to the model param name it was paired with. Empty when
        the model has no params missing from the checkpoint (extra checkpoint params are not reported in that case).

    Raises:
        ValueError: If params are missing and cannot be paired one-to-one with matching shapes. The missing and
            extra params are printed before raising so that the patterns can be written by hand.
    """
    model = pi_model.set_module(module, param_path=param_path)

    # Initialize reference params; only their names, shapes and dtypes are used below.
    obs, act = model.fake_obs(), model.fake_act()
    real_params = model.init_params(jax.random.key(0), obs, act)

    real_params = _transforms.flatten_dict(real_params)
    loaded_params = _transforms.flatten_dict(model.params)

    # Sorting by (shape, name) lines up same-shaped candidates on both sides for the positional pairing below.
    missing = sorted(set(real_params) - set(loaded_params), key=lambda n: (real_params[n].shape, n))
    extra = sorted(set(loaded_params) - set(real_params), key=lambda n: (loaded_params[n].shape, n))

    if not missing:
        return {}

    # `missing` is guaranteed to be non-empty from here on.
    if len(missing) == len(extra):
        patterns = dict(zip(extra, missing, strict=True))
        # Confirm that all shapes match.
        for k, v in patterns.items():
            if loaded_params[k].shape != real_params[v].shape:
                print("Shape mismatch between checkpoint and model candidates:")
                print(k, loaded_params[k].shape)
                print(v, real_params[v].shape)
                print()
                break
        else:
            # No break: every pair has matching shapes.
            return patterns

    # Getting here means that there's a mismatch but we were unable to pair the params automatically.
    # Print both sides so that the patterns can be created by hand.
    print(f"{len(missing)} missing params in checkpoint:")
    for name in missing:
        p = real_params[name]
        print(name, p.shape, str(p.dtype))
    print()

    if extra:
        print(f"{len(extra)} extra params in checkpoint:")
        for name in extra:
            p = loaded_params[name]
            print(name, p.shape, str(p.dtype))
        print()

    raise ValueError("Automatic generation is not possible. Please see the outputs and create the patterns by hand.")
245+
246+
176247def _load_params (
177248 path : pathlib .Path , params_spec : at .PyTree | None = None , sharding : jax .sharding .Sharding | None = None
178249):
0 commit comments