Skip to content

Commit ac44a00

Browse files
committed
Utilities to convert from internal checkpoints
1 parent 2aeb561 commit ac44a00

File tree

3 files changed

+159
-6
lines changed

src/openpi/models/exported.py

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Used to test internal pi checkpoints and provides utilities to convert them to openpi checkpoints.
44
"""
55

6+
from collections.abc import Mapping
67
import pathlib
78
from typing import Any
89

@@ -20,12 +21,27 @@
2021
from openpi.shared import normalize as _normalize
2122
import openpi.shared.array_typing as at
2223
import openpi.shared.download as download
24+
import openpi.transforms as _transforms
2325

2426

2527
def convert_to_openpi(
26-
ckpt_dir: pathlib.Path | str, processor: str, out_dir: pathlib.Path | str, param_path: str = "decoder"
28+
ckpt_dir: pathlib.Path | str,
29+
processor: str,
30+
out_dir: pathlib.Path | str,
31+
*,
32+
param_path: str = "decoder",
33+
transform: Mapping[str, None] | None = None,
2734
) -> None:
28-
"""Convert a monopi checkpoint to an openpi checkpoint."""
35+
"""Convert an internal checkpoint to an openpi checkpoint.
36+
37+
Args:
38+
ckpt_dir: The directory containing the internal exported model.
39+
processor: The processor name to use to extract the norm stats.
40+
out_dir: The directory to save the openpi checkpoint.
41+
param_path: The path to the parameters within the overall param structure. Can include "/" to support nesting.
42+
transform: Optional transform patterns to use when converting the checkpoint params. Each key maps from the
43+
original param name to the openpi param name. See `determine_transform_patterns` for more details.
44+
"""
2945
out_dir = pathlib.Path(out_dir)
3046
if out_dir.exists():
3147
raise FileExistsError(f"Output directory already exists: {out_dir}")
@@ -43,7 +59,9 @@ def convert_to_openpi(
4359
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
4460
params = params[part]
4561

46-
# Load the monopi model.
62+
if transform is not None:
63+
params = _transforms.transform_dict(transform, params)
64+
4765
# Save params.
4866
ckpt = ocp.StandardCheckpointer()
4967
ckpt.save(out_dir / "params", {"params": params})
@@ -55,7 +73,7 @@ def convert_to_openpi(
5573

5674
@struct.dataclass
5775
class PiModel(_model.BaseModel):
58-
"""A model loaded from a monopi checkpoint model directory."""
76+
"""A model loaded from an internal exported model directory."""
5977

6078
params: at.Params
6179

@@ -66,7 +84,7 @@ class PiModel(_model.BaseModel):
6684

6785
@classmethod
6886
def from_checkpoint(cls, ckpt_dir: pathlib.Path | str) -> "PiModel":
69-
"""Load a model from a monopi model checkpoint directory. Must point at the "model" sub-directory."""
87+
"""Load a model from the internal checkpoint directory. Must point at the "model" sub-directory."""
7088
ckpt_dir = download.maybe_download(str(ckpt_dir))
7189
with (ckpt_dir / "graph").open("rb") as f:
7290
exported = jax.export.deserialize(f.read())
@@ -173,6 +191,59 @@ def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model
173191
)
174192

175193

194+
def determine_transform_patterns(
    pi_model: PiModel, module: common.BaseModule, *, param_path: str = "decoder"
) -> dict[str, str]:
    """Determine the transform patterns to use when converting an internal checkpoint to an openpi checkpoint.

    Initializes the target module from scratch and compares its flattened param names against the
    params loaded in `pi_model`. If the sets of mismatched names pair up one-to-one by shape, a
    rename mapping is returned; otherwise the mismatches are printed so patterns can be written
    by hand.

    Args:
        pi_model: The model holding the loaded checkpoint params.
        module: The module whose freshly-initialized param structure defines the target names.
        param_path: The path to the parameters within the overall param structure. Can include "/"
            to support nesting.

    Returns:
        A mapping from checkpoint param names to openpi param names, usable by
        `transforms.transform_dict`. Empty if the checkpoint already matches the model.

    Raises:
        ValueError: If the mismatch cannot be resolved automatically; the mismatched params are
            printed first to aid writing the patterns manually.
    """
    model = pi_model.set_module(module, param_path=param_path)

    # Initialize the model from scratch to obtain the reference param structure.
    obs, act = model.fake_obs(), model.fake_act()
    real_params = model.init_params(jax.random.key(0), obs, act)

    real_params = _transforms.flatten_dict(real_params)
    loaded_params = _transforms.flatten_dict(model.params)

    # Sort by shape first so that same-shaped params line up positionally between the two lists.
    missing = sorted(set(real_params) - set(loaded_params), key=lambda n: (real_params[n].shape, n))
    extra = sorted(set(loaded_params) - set(real_params), key=lambda n: (loaded_params[n].shape, n))

    if not missing:
        # Checkpoint already contains every param the model expects; no renaming needed.
        return {}

    if len(missing) == len(extra):
        # Candidate mapping: pair up extra checkpoint names with missing model names.
        patterns = dict(zip(extra, missing, strict=True))
        # Confirm that all shapes match.
        for k, v in patterns.items():
            if loaded_params[k].shape != real_params[v].shape:
                print("Shape mismatch between checkpoint and model candidates:")
                print(k, loaded_params[k].shape)
                print(v, real_params[v].shape)
                print()
                break
        else:
            return patterns

    # Getting here means that there's a mismatch but we were unable to determine the rename
    # patterns automatically. Print the details so the patterns can be created by hand.

    print(f"{len(missing)} missing params in checkpoint:")
    for name in missing:
        p = real_params[name]
        print(name, p.shape, str(p.dtype))
    print()

    if extra:
        print(f"{len(extra)} extra params in checkpoint:")
        for name in extra:
            p = loaded_params[name]
            print(name, p.shape, str(p.dtype))
        print()

    raise ValueError("Automatic generation is not possible. Please see the outputs and create the patterns by hand.")
245+
246+
176247
def _load_params(
177248
path: pathlib.Path, params_spec: at.PyTree | None = None, sharding: jax.sharding.Sharding | None = None
178249
):

src/openpi/transforms.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections.abc import Callable, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22
import dataclasses
3+
import re
34
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable
45

56
import flax.traverse_util as traverse_util
@@ -235,6 +236,60 @@ def unflatten_dict(tree: dict) -> at.PyTree:
235236
return traverse_util.unflatten_dict(tree, sep="/")
236237

237238

239+
def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:
    """Transform the structure of a nested dictionary using a set of patterns.

    The transformation is defined using the `patterns` dictionary. The keys are the
    input keys that should be matched and the values are the new names inside the output
    dictionary. If the value is None, the input key is removed.

    Both keys and values should represent flattened paths using '/' as the separator.
    Keys can be regular expressions and values can include backreferences to the
    matched groups (see `re.sub` for more details). Note that the regular expression
    must match the entire key.

    The order inside the `patterns` dictionary is important. Only the first pattern that
    matches the input key will be used.

    See unit tests for more examples.

    Args:
        patterns: A mapping from old keys to new keys.
        tree: The nested dictionary to transform.

    Returns:
        The transformed nested dictionary.
    """
    flat = flatten_dict(tree)

    # Pre-compile each pattern once; order is preserved so the first match wins.
    compiled = [(re.compile(old), new) for old, new in patterns.items()]

    result: dict = {}
    for key, value in flat.items():
        # Default: keep the key unchanged when no pattern matches.
        new_key: str | None = key
        for regex, repl in compiled:
            if regex.fullmatch(key):
                new_key = None if repl is None else regex.sub(repl, key, count=1)
                break

        if new_key is None:
            # A None replacement drops this entry from the output.
            continue
        if new_key in result:
            raise ValueError(f"Key '{new_key}' already exists in output")
        result[new_key] = value

    # Validate the output structure to make sure that it can be unflattened: in sorted
    # order, a leaf that is a path-prefix of its successor would collide with a subtree.
    ordered = sorted(result)
    for name, next_name in zip(ordered, ordered[1:]):
        if next_name.startswith(name + "/"):
            raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'")

    return unflatten_dict(result)
291+
292+
238293
def apply_tree(
239294
tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
240295
) -> at.PyTree[T]:

src/openpi/transforms_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import openpi.models.tokenizer as _tokenizer
45
import openpi.transforms as _transforms
@@ -86,3 +87,29 @@ def test_tokenize_prompt_default():
8687
tok_prompt, tok_mask = tokenizer.tokenize("This is a default prompt")
8788
assert np.allclose(tok_prompt, data["tokenized_prompt"])
8889
assert np.allclose(tok_mask, data["tokenized_prompt_mask"])
90+
91+
92+
def test_transform_dict():
    """Exercise rename, removal, full-match semantics, and backreferences in transform_dict."""
    # Rename and remove keys. (Avoid naming the local `input` — it shadows the builtin.)
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, tree)
    assert output == {"a": {"c": 1}}

    # Raises an error since the renamed key conflicts with an existing key.
    with pytest.raises(ValueError, match="Key 'a/c' already exists in output"):
        _transforms.transform_dict({"a/b": "a/c"}, tree)

    # Full match is required and so nothing will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a": None}, tree)
    assert output == tree

    # The regex matches the entire key and so the entire input will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a.+": None}, tree)
    assert output == {}

    # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'.
    tree = {"a": {"b": 1, "c": 1}, "b": {"c": 2}}
    output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, tree)
    assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}}

0 commit comments

Comments
 (0)