Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions src/openpi/models/exported.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Used to test internal pi checkpoints and provides utilities to convert them to openpi checkpoints.
"""

from collections.abc import Mapping
import pathlib
from typing import Any

Expand All @@ -20,12 +21,27 @@
from openpi.shared import normalize as _normalize
import openpi.shared.array_typing as at
import openpi.shared.download as download
import openpi.transforms as _transforms


def convert_to_openpi(
ckpt_dir: pathlib.Path | str, processor: str, out_dir: pathlib.Path | str, param_path: str = "decoder"
ckpt_dir: pathlib.Path | str,
processor: str,
out_dir: pathlib.Path | str,
*,
param_path: str = "decoder",
transform: Mapping[str, None] | None = None,
) -> None:
"""Convert a monopi checkpoint to an openpi checkpoint."""
"""Convert an internal checkpoint to an openpi checkpoint.

Args:
ckpt_dir: The directory containing the internal exported model.
processor: The processor name to use to extract the norm stats.
out_dir: The directory to save the openpi checkpoint.
param_path: The path to the parameters within the overall param structure. Can include "/" to support nesting.
transform: Optional transform patterns to use when converting the checkpoint params. Each key maps from the
original param name to the openpi param name. See `determine_transform_patterns` for more details.
"""
out_dir = pathlib.Path(out_dir)
if out_dir.exists():
raise FileExistsError(f"Output directory already exists: {out_dir}")
Expand All @@ -43,7 +59,9 @@ def convert_to_openpi(
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
params = params[part]

# Load the monopi model.
if transform is not None:
params = _transforms.transform_dict(transform, params)

# Save params.
ckpt = ocp.StandardCheckpointer()
ckpt.save(out_dir / "params", {"params": params})
Expand All @@ -55,7 +73,7 @@ def convert_to_openpi(

@struct.dataclass
class PiModel(_model.BaseModel):
"""A model loaded from a monopi checkpoint model directory."""
"""A model loaded from an internal exported model directory."""

params: at.Params

Expand All @@ -66,7 +84,7 @@ class PiModel(_model.BaseModel):

@classmethod
def from_checkpoint(cls, ckpt_dir: pathlib.Path | str) -> "PiModel":
"""Load a model from a monopi model checkpoint directory. Must point at the "model" sub-directory."""
"""Load a model from the internal checkpoint directory. Must point at the "model" sub-directory."""
ckpt_dir = download.maybe_download(str(ckpt_dir))
with (ckpt_dir / "graph").open("rb") as f:
exported = jax.export.deserialize(f.read())
Expand Down Expand Up @@ -173,6 +191,59 @@ def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model
)


def determine_transform_patterns(
    pi_model: PiModel, module: common.BaseModule, *, param_path: str = "decoder"
) -> dict[str, str]:
    """Determine the transform patterns to use when converting an internal checkpoint to an openpi checkpoint.

    The returned patterns can be used by `transforms.transform_dict` to convert the checkpoint
    params to the openpi format.

    Args:
        pi_model: The model loaded from the internal exported checkpoint.
        module: The module that defines the expected openpi param structure.
        param_path: The path to the parameters within the overall param structure.

    Returns:
        A mapping from checkpoint param names to openpi param names. Empty if no renaming is needed.

    Raises:
        ValueError: If the patterns cannot be determined automatically. Diagnostics describing the
            missing/extra params are printed before raising to help create the patterns by hand.
    """
    model = pi_model.set_module(module, param_path=param_path)

    # Initialize the module with fake inputs to obtain the expected ("real") param structure.
    obs, act = model.fake_obs(), model.fake_act()
    real_params = model.init_params(jax.random.key(0), obs, act)

    real_params = _transforms.flatten_dict(real_params)
    loaded_params = _transforms.flatten_dict(model.params)

    # Sort both sides by shape first so that params with matching shapes line up positionally.
    missing = sorted(set(real_params) - set(loaded_params), key=lambda n: (real_params[n].shape, n))
    extra = sorted(set(loaded_params) - set(real_params), key=lambda n: (loaded_params[n].shape, n))

    # Every expected param is already present under its openpi name: nothing to rename.
    if not missing:
        return {}

    if len(missing) == len(extra):
        # Candidate mapping: pair up extra and missing params in shape-sorted order.
        patterns = dict(zip(extra, missing, strict=True))
        # Confirm that all shapes match; on the first mismatch fall through to the diagnostics below.
        for k, v in patterns.items():
            if loaded_params[k].shape != real_params[v].shape:
                print("Shape mismatch between checkpoint and model candidates:")
                print(k, loaded_params[k].shape)
                print(v, real_params[v].shape)
                print()
                break
        else:
            return patterns

    # Getting here means that there's a mismatch that we were unable to resolve automatically.
    # Print diagnostics so the user can create the patterns by hand. `missing` is non-empty here.
    print(f"{len(missing)} missing params in checkpoint:")
    for name in missing:
        p = real_params[name]
        print(name, p.shape, str(p.dtype))
    print()

    if extra:
        print(f"{len(extra)} extra params in checkpoint:")
        for name in extra:
            p = loaded_params[name]
            print(name, p.shape, str(p.dtype))
        print()

    raise ValueError("Automatic generation is not possible. Please see the outputs and create the patterns by hand.")


def _load_params(
path: pathlib.Path, params_spec: at.PyTree | None = None, sharding: jax.sharding.Sharding | None = None
):
Expand Down
57 changes: 56 additions & 1 deletion src/openpi/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
import dataclasses
import re
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable

import flax.traverse_util as traverse_util
Expand Down Expand Up @@ -235,6 +236,60 @@ def unflatten_dict(tree: dict) -> at.PyTree:
return traverse_util.unflatten_dict(tree, sep="/")


def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:
    """Transform the structure of a nested dictionary using a set of patterns.

    The transformation is defined using the `patterns` dictionary. The keys are the
    input keys that should be matched and the values are the new names inside the output
    dictionary. If the value is None, the input key is removed.

    Both keys and values should represent flattened paths using '/' as the separator.
    Keys can be regular expressions and values can include backreferences to the
    matched groups (see `re.sub` for more details). Note that the regular expression
    must match the entire key.

    The order inside the `patterns` dictionary is important. Only the first pattern that
    matches the input key will be used.

    See unit tests for more examples.

    Args:
        patterns: A mapping from old keys to new keys.
        tree: The nested dictionary to transform.

    Returns:
        The transformed nested dictionary.
    """
    flat = flatten_dict(tree)

    # Pre-compile all patterns once, preserving the caller's ordering.
    compiled = [(re.compile(src), dst) for src, dst in patterns.items()]

    def rename(key: str) -> str | None:
        # The first fully-matching pattern wins; unmatched keys pass through unchanged.
        for pattern, repl in compiled:
            if pattern.fullmatch(key):
                return None if repl is None else pattern.sub(repl, key, count=1)
        return key

    result = {}
    for old_key, value in flat.items():
        new_key = rename(old_key)
        if new_key is None:
            # A None replacement drops the entry entirely.
            continue
        if new_key in result:
            raise ValueError(f"Key '{new_key}' already exists in output")
        result[new_key] = value

    # A leaf whose path is a strict prefix of another leaf's path cannot be unflattened:
    # check adjacent names in sorted order, where any such pair must be neighbors.
    ordered = sorted(result)
    for name, next_name in zip(ordered, ordered[1:]):
        if next_name.startswith(name + "/"):
            raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'")

    return unflatten_dict(result)


def apply_tree(
tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
) -> at.PyTree[T]:
Expand Down
27 changes: 27 additions & 0 deletions src/openpi/transforms_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

import openpi.models.tokenizer as _tokenizer
import openpi.transforms as _transforms
Expand Down Expand Up @@ -86,3 +87,29 @@ def test_tokenize_prompt_default():
tok_prompt, tok_mask = tokenizer.tokenize("This is a default prompt")
assert np.allclose(tok_prompt, data["tokenized_prompt"])
assert np.allclose(tok_mask, data["tokenized_prompt_mask"])


def test_transform_dict():
    """Unit tests for `transforms.transform_dict` covering rename, removal, and backreferences."""
    # Rename and remove keys. (Use `tree` rather than shadowing the builtin `input`.)
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, tree)
    assert output == {"a": {"c": 1}}

    # Raises an error since the renamed key conflicts with an existing key.
    with pytest.raises(ValueError, match="Key 'a/c' already exists in output"):
        _transforms.transform_dict({"a/b": "a/c"}, tree)

    # A full match is required and so nothing will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a": None}, tree)
    assert output == tree

    # The regex matches the entire key and so the entire input will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a.+": None}, tree)
    assert output == {}

    # Replace keys using backreferences. All leaves named 'c' are renamed to 'd'.
    tree = {"a": {"b": 1, "c": 1}, "b": {"c": 2}}
    output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, tree)
    assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}}