From 865b8aef66fbfbc881b3a1c738a0cbbc5982aec9 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 1 Dec 2025 02:49:35 +0000 Subject: [PATCH 1/5] feat: implement model weight conversion from linen to nnx --- tools/weight_inspector/compare_checkpoint.py | 150 ++++++++++++++++++ .../convert_linen_checkpoint_to_nnx.py | 125 +++++++++++++++ tools/weight_inspector/weight_inspector.py | 2 - 3 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 tools/weight_inspector/compare_checkpoint.py create mode 100644 tools/weight_inspector/convert_linen_checkpoint_to_nnx.py diff --git a/tools/weight_inspector/compare_checkpoint.py b/tools/weight_inspector/compare_checkpoint.py new file mode 100644 index 000000000..5f2be8c47 --- /dev/null +++ b/tools/weight_inspector/compare_checkpoint.py @@ -0,0 +1,150 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +r"""This is to inspect/analyze two checkpoint weights with the same structure to find differences. + + +Usage: + +python3 -m tools/weight_inspector/compare_checkpoint.py --lhs /model-left/runner_direct_1/checkpoints/0/items --rhs /model-right/runner_direct_1/checkpoints/0/items + +""" + +import argparse +import jax +import orbax.checkpoint as ocp +from typing import Any, Dict, Set +import pprint +import numpy as np + +def load_params_from_path(checkpoint_dir: str) -> Dict[str, Any] | None: + + if not checkpoint_dir: + raise ValueError("checkpoint_dir must be provided.") + print(f"Loading quantized params checkpoint from: {checkpoint_dir}") + orbax_checkpointer = ocp.PyTreeCheckpointer() + try: + restored_object = orbax_checkpointer.restore(checkpoint_dir) + if "params" in restored_object: + print(f"Successfully restored checkpoint from {checkpoint_dir}") + return restored_object["params"] + else: + print(f"Error: 'params' key not found in the restored checkpoint at {checkpoint_dir}") + return None + except Exception as e: + print(f"An error occurred during checkpoint restoration from {checkpoint_dir}: {e}") + return None + +def get_tree_paths(tree: Any) -> Set[str]: + flat_with_path, _ = jax.tree_util.tree_flatten_with_path(tree) + return {jax.tree_util.keystr(p) for p, _ in flat_with_path} + +def compare_quantized_checkpoints(left_path: str, right_path: str, rtol: float = 1e-3, atol: float = 1e-3) -> bool: + print(f"\n--- Comparing Checkpoints ---") + print(f" Left checkpoint path: {left_path}") + print(f" Right checkpoint path: {right_path}") + + params_left = load_params_from_path(left_path) + params_right = load_params_from_path(right_path) + + if params_left is None or params_right is None: + print("❌ Loading failed for one or both checkpoints. Cannot compare.") + return False + + flat_left, struct1 = jax.tree_util.tree_flatten_with_path(params_left) + flat_right, struct2 = jax.tree_util.tree_flatten_with_path(params_right) + + if struct1 != struct2: + print("❌ Tree structures differ.") + paths1 = get_tree_paths(params_left) + paths2 = get_tree_paths(params_right) + in_left_only = sorted(list(paths1 - paths2)) + if in_left_only: + print("\n Paths unique to left checkpoint:") + for p in in_left_only: print(f" {p}") + in_right_only = sorted(list(paths2 - paths1)) + if in_right_only: + print("\n Paths unique to right checkpoint:") + for p in in_right_only: print(f" {p}") + return False + + print("✅ Tree structures are identical.") + + map_left = {jax.tree_util.keystr(p): v for p, v in flat_left} + map_right = {jax.tree_util.keystr(p): v for p, v in flat_right} + + all_equal = True + print("\n--- Comparing Leaf Values ---") + for key in sorted(map_left.keys()): + left_values = map_left[key] + right_values = map_right[key] + + if type(left_values) is not type(right_values): + print(f"❌ Type mismatch at {key}: {type(left_values)} vs {type(right_values)}") + all_equal = False + continue + + if isinstance(left_values, jax.Array): + if left_values.shape != right_values.shape: + print(f"❌ Shape mismatch at {key}: {left_values.shape} vs {right_values.shape}") + all_equal = False; continue + if left_values.dtype != right_values.dtype: + print(f"❌ Dtype mismatch at {key}: {left_values.dtype} vs {right_values.dtype}") + all_equal = False; continue + + try: + left_cpu = jax.device_get(left_values) + right_cpu = jax.device_get(right_values) + except Exception as e: + print(f"❌ Error during jax.device_get at {key}: {e}") + all_equal = False; continue + + if not np.allclose(left_cpu, right_cpu, rtol=rtol, atol=atol): + print(f"❌ Numerical difference in JAX Array at {key}.") + diff = np.abs(left_cpu - right_cpu) + print(f" Max diff: {np.max(diff)}, Mean diff: {np.mean(diff)}") + all_equal = False + elif isinstance(left_values, dict): + if left_values != right_values: + print(f"❌ Dict difference at {key}:") + pprint.pprint(f" Left: {left_values}", width=120) + pprint.pprint(f" Right: {right_values}", width=120) + all_equal = False + elif left_values != right_values: + try: + if np.isscalar(left_values) and np.isscalar(right_values) and np.allclose(np.array(left_values), np.array(right_values), rtol=rtol, atol=atol): + continue + except (TypeError, ValueError): + pass + print(f"❌ Value difference at {key}: {left_values} vs {right_values}") + all_equal = False + + if all_equal: + print("\n✅ All compared leaf values are identical or numerically close.") + else: + print("\n❌ Differences found in leaf values. See details above.") + return all_equal + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--lhs", type=str, required=True) + parser.add_argument("--rhs", type=str, required=True) + + args = parser.parse_args() + are_checkpoints_same = compare_quantized_checkpoints(args.lhs, args.rhg) + print(f"\nComparison result: {are_checkpoints_same}") + + + diff --git a/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py new file mode 100644 index 000000000..f8a898636 --- /dev/null +++ b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py @@ -0,0 +1,125 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +r"""This is to convert checkpoint weight from linen to nnx structure. + +Usage: + +python -m tools/weight_inspector/convert_checkpoint.py --source_path="/original-model/runner_direct_1/checkpoints/14/" --output_path="/converted-model/runner_direct_1/checkpoints/14/" + +""" + +import jax +import orbax.checkpoint as ocp +from typing import Any, Dict +import numpy as np +import sys +import argparse # Import argparse +from etils import epath +import pprint + + + +def load_full_checkpoint(checkpoint_dir: epath.Path) -> Dict[str, Any] | None: + """Loads the entire PyTree checkpoint using Orbax.""" + items_path = checkpoint_dir / 'items' + print(f"Loading full checkpoint from: {items_path}") + if not items_path.exists(): + print(f"Error: Checkpoint items not found: {items_path}") + return None + try: + orbax_checkpointer = ocp.PyTreeCheckpointer() + restored_object = orbax_checkpointer.restore(items_path) + print(f"Successfully restored full checkpoint from {items_path}") + return restored_object + except Exception as e: + print(f"An error occurred during checkpoint restoration: {e}") + return None + +def wrap_array_leaves(tree: Any) -> Any: + """Recursively wraps only JAX/NumPy array leaf nodes in {'value': array} format.""" + def _wrap(leaf): + if isinstance(leaf, (jax.Array, np.ndarray)): + return {'value': leaf} + return leaf # Keep scalars as they are + return jax.tree_util.tree_map(_wrap, tree) + +def main(args): + source_path = epath.Path(args.source_path) + output_path = epath.Path(args.output_path) + + print(f"--- Converting Checkpoint ---") + print(f" Source (V1 - main): {source_path}") + print(f" Output (V2 - modelspy format): {output_path}") + + restored_main = load_full_checkpoint(source_path) + if restored_main is None: + sys.exit(1) + + if 'params' not in restored_main or 'params' not in restored_main['params']: + print("Error: Expected structure {'params': {'params': ...}} not found in source.") + sys.exit(1) + + # 1. Extract the core parameters from the main model + main_core_params = restored_main['params']['params'] + # Wrap only the array leaves within the core parameters + nnx_style_core_params = wrap_array_leaves(main_core_params) + + # 2. Process opt_state: Wrap only array leaves + if 'opt_state' in restored_main: + new_opt_state = wrap_array_leaves(restored_main['opt_state']) + else: + new_opt_state = None + print("Warning: 'opt_state' not found in source checkpoint.") + + # 3. Construct the new state to save, matching the modelspy structure + state_to_save = { + 'params': nnx_style_core_params, + 'opt_state': new_opt_state, + 'step': restored_main.get('step'), # Keep step as a scalar + 'graphdef': None, # Add to match modelspy structure + } + + print("\n--- Structure of State to Save (types) ---") + pprint.pprint(jax.tree_util.tree_map(lambda x: type(x), state_to_save), depth=4) + + save_items_path = output_path / 'items' + print(f"--- Saving converted checkpoint to {save_items_path} ---") + + if jax.process_index() == 0: + output_path.mkdir(parents=True, exist_ok=True) + + # Barrier to ensure directory is created before other processes proceed + if jax.process_count() > 1: + jax.experimental.multihost_utils.sync_global_devices("output_dir_creation") + + checkpointer = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()) + try: + checkpointer.save(save_items_path, state_to_save) + checkpointer.wait_until_finished() + print(f"✅ Conversion complete. Saved to {save_items_path}") + except Exception as e: + print(f"❌ Error during saving checkpoint: {e}") + sys.exit(1) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert Flax checkpoint format.') + parser.add_argument('--source_path', type=str, required=True, + help='Path to the source "main" model checkpoint directory (containing items/).') + parser.add_argument('--output_path', type=str, required=True, + help='Path to save the converted "modelspy" format checkpoint directory.') + + args = parser.parse_args() + main(args) diff --git a/tools/weight_inspector/weight_inspector.py b/tools/weight_inspector/weight_inspector.py index 8959e69ad..ce50bdddb 100644 --- a/tools/weight_inspector/weight_inspector.py +++ b/tools/weight_inspector/weight_inspector.py @@ -67,5 +67,3 @@ def inspect_weights(left_path, right_path): args = parser.parse_args() inspect_weights(args.lhs, args.rhs) - - args = parser.parse_args() From 12230d896151ee0f0feb681e736d565c02a3cb4d Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 1 Dec 2025 02:53:23 +0000 Subject: [PATCH 2/5] fix: wrong file name --- tools/weight_inspector/convert_linen_checkpoint_to_nnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py index f8a898636..5621f96b2 100644 --- a/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py +++ b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py @@ -17,7 +17,7 @@ Usage: -python -m tools/weight_inspector/convert_checkpoint.py --source_path="/original-model/runner_direct_1/checkpoints/14/" --output_path="/converted-model/runner_direct_1/checkpoints/14/" +python -m tools/weight_inspector/convert_linen_checkpoint_to_nnx.py --source_path="/original-model/runner_direct_1/checkpoints/14/" --output_path="/converted-model/runner_direct_1/checkpoints/14/" """ From c42eb266cecf5977dc25a84ae0248c7e8e6dc936 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 1 Dec 2025 04:07:36 +0000 Subject: [PATCH 3/5] fix: wrong path --- tools/weight_inspector/compare_checkpoint.py | 4 ++-- tools/weight_inspector/convert_linen_checkpoint_to_nnx.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/weight_inspector/compare_checkpoint.py b/tools/weight_inspector/compare_checkpoint.py index 5f2be8c47..b7179cf2e 100644 --- a/tools/weight_inspector/compare_checkpoint.py +++ b/tools/weight_inspector/compare_checkpoint.py @@ -18,7 +18,7 @@ Usage: -python3 -m tools/weight_inspector/compare_checkpoint.py --lhs /model-left/runner_direct_1/checkpoints/0/items --rhs /model-right/runner_direct_1/checkpoints/0/items +python tools/weight_inspector/compare_checkpoint.py --lhs /model-left/runner_direct_1/checkpoints/0/items --rhs /model-right/runner_direct_1/checkpoints/0/items """ @@ -143,7 +143,7 @@ def compare_quantized_checkpoints(left_path: str, right_path: str, rtol: float = parser.add_argument("--rhs", type=str, required=True) args = parser.parse_args() - are_checkpoints_same = compare_quantized_checkpoints(args.lhs, args.rhg) + are_checkpoints_same = compare_quantized_checkpoints(args.lhs, args.rhs) print(f"\nComparison result: {are_checkpoints_same}") diff --git a/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py index 5621f96b2..c985d8be8 100644 --- a/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py +++ b/tools/weight_inspector/convert_linen_checkpoint_to_nnx.py @@ -17,7 +17,7 @@ Usage: -python -m tools/weight_inspector/convert_linen_checkpoint_to_nnx.py --source_path="/original-model/runner_direct_1/checkpoints/14/" --output_path="/converted-model/runner_direct_1/checkpoints/14/" +python tools/weight_inspector/convert_linen_checkpoint_to_nnx.py --source_path="/original-model/runner_direct_1/checkpoints/14/" --output_path="/converted-model/runner_direct_1/checkpoints/14/" """ @@ -61,8 +61,8 @@ def main(args): output_path = epath.Path(args.output_path) print(f"--- Converting Checkpoint ---") - print(f" Source (V1 - main): {source_path}") - print(f" Output (V2 - modelspy format): {output_path}") + print(f" Source: {source_path}") + print(f" Output: {output_path}") restored_main = load_full_checkpoint(source_path) if restored_main is None: From fbdc9ca7732b25ca5ebe8b268e084f8b4898bb1d Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 1 Dec 2025 07:05:05 +0000 Subject: [PATCH 4/5] wrong name --- tools/weight_inspector/compare_checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/weight_inspector/compare_checkpoint.py b/tools/weight_inspector/compare_checkpoint.py index b7179cf2e..6e14b7343 100644 --- a/tools/weight_inspector/compare_checkpoint.py +++ b/tools/weight_inspector/compare_checkpoint.py @@ -33,7 +33,7 @@ def load_params_from_path(checkpoint_dir: str) -> Dict[str, Any] | None: if not checkpoint_dir: raise ValueError("checkpoint_dir must be provided.") - print(f"Loading quantized params checkpoint from: {checkpoint_dir}") + print(f"Loading params checkpoint from: {checkpoint_dir}") orbax_checkpointer = ocp.PyTreeCheckpointer() try: restored_object = orbax_checkpointer.restore(checkpoint_dir) @@ -51,7 +51,7 @@ def get_tree_paths(tree: Any) -> Set[str]: flat_with_path, _ = jax.tree_util.tree_flatten_with_path(tree) return {jax.tree_util.keystr(p) for p, _ in flat_with_path} -def compare_quantized_checkpoints(left_path: str, right_path: str, rtol: float = 1e-3, atol: float = 1e-3) -> bool: +def compare_checkpoints(left_path: str, right_path: str, rtol: float = 1e-3, atol: float = 1e-3) -> bool: print(f"\n--- Comparing Checkpoints ---") print(f" Left checkpoint path: {left_path}") print(f" Right checkpoint path: {right_path}") @@ -143,7 +143,7 @@ def compare_quantized_checkpoints(left_path: str, right_path: str, rtol: float = parser.add_argument("--rhs", type=str, required=True) args = parser.parse_args() - are_checkpoints_same = compare_quantized_checkpoints(args.lhs, args.rhs) + are_checkpoints_same = compare_checkpoints(args.lhs, args.rhs) print(f"\nComparison result: {are_checkpoints_same}") From c84d93fdb12b2eb023af4a0022030640d4936a36 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 1 Dec 2025 07:40:55 +0000 Subject: [PATCH 5/5] fix: allow numpy and jax array --- tools/weight_inspector/compare_checkpoint.py | 66 ++++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/tools/weight_inspector/compare_checkpoint.py b/tools/weight_inspector/compare_checkpoint.py index 6e14b7343..5e00728e9 100644 --- a/tools/weight_inspector/compare_checkpoint.py +++ b/tools/weight_inspector/compare_checkpoint.py @@ -21,13 +21,13 @@ python tools/weight_inspector/compare_checkpoint.py --lhs /model-left/runner_direct_1/checkpoints/0/items --rhs /model-right/runner_direct_1/checkpoints/0/items """ - import argparse import jax import orbax.checkpoint as ocp from typing import Any, Dict, Set import pprint import numpy as np +import jax.tree_util def load_params_from_path(checkpoint_dir: str) -> Dict[str, Any] | None: @@ -91,41 +91,56 @@ def compare_checkpoints(left_path: str, right_path: str, rtol: float = 1e-3, ato left_values = map_left[key] right_values = map_right[key] - if type(left_values) is not type(right_values): - print(f"❌ Type mismatch at {key}: {type(left_values)} vs {type(right_values)}") - all_equal = False - continue + is_left_array = isinstance(left_values, (jax.Array, np.ndarray)) + is_right_array = isinstance(right_values, (jax.Array, np.ndarray)) - if isinstance(left_values, jax.Array): - if left_values.shape != right_values.shape: - print(f"❌ Shape mismatch at {key}: {left_values.shape} vs {right_values.shape}") + if is_left_array and is_right_array: + try: + # Normalize both to numpy arrays on CPU + left_cpu = np.array(jax.device_get(left_values)) + right_cpu = np.array(jax.device_get(right_values)) + except Exception as e: + print(f"❌ Error during array conversion at {key}: {e}") all_equal = False; continue - if left_values.dtype != right_values.dtype: - print(f"❌ Dtype mismatch at {key}: {left_values.dtype} vs {right_values.dtype}") + + if left_cpu.shape != right_cpu.shape: + print(f"❌ Shape mismatch at {key}: {left_cpu.shape} vs {right_cpu.shape}") all_equal = False; continue + # Basic dtype compatibility check + if left_cpu.dtype != right_cpu.dtype: + print(f"⚠️ Dtype mismatch at {key}: {left_cpu.dtype} vs {right_cpu.dtype}. Attempting numerical comparison.") + try: - left_cpu = jax.device_get(left_values) - right_cpu = jax.device_get(right_values) + if not np.allclose(left_cpu, right_cpu, rtol=rtol, atol=atol): + print(f"❌ Numerical difference in Array at {key} ({left_cpu.dtype} vs {right_cpu.dtype}).") + diff = np.abs(left_cpu - right_cpu) + print(f" Max diff: {np.max(diff)}, Mean diff: {np.mean(diff)}") + all_equal = False + except TypeError as e: + print(f"❌ TypeError during np.allclose at {key} ({left_cpu.dtype} vs {right_cpu.dtype}): {e}") + all_equal = False except Exception as e: - print(f"❌ Error during jax.device_get at {key}: {e}") - all_equal = False; continue + print(f"❌ Error during np.allclose at {key}: {e}") + all_equal = False - if not np.allclose(left_cpu, right_cpu, rtol=rtol, atol=atol): - print(f"❌ Numerical difference in JAX Array at {key}.") - diff = np.abs(left_cpu - right_cpu) - print(f" Max diff: {np.max(diff)}, Mean diff: {np.mean(diff)}") - all_equal = False + elif is_left_array != is_right_array: + print(f"❌ Type mismatch at {key}: {type(left_values)} vs {type(right_values)}") + all_equal = False elif isinstance(left_values, dict): - if left_values != right_values: + if not isinstance(right_values, dict) or left_values != right_values: print(f"❌ Dict difference at {key}:") pprint.pprint(f" Left: {left_values}", width=120) pprint.pprint(f" Right: {right_values}", width=120) all_equal = False elif left_values != right_values: try: - if np.isscalar(left_values) and np.isscalar(right_values) and np.allclose(np.array(left_values), np.array(right_values), rtol=rtol, atol=atol): - continue + # Scalar numerical comparison + if np.isscalar(left_values) and np.isscalar(right_values) and \ + isinstance(left_values, (int, float, np.number)) and \ + isinstance(right_values, (int, float, np.number)): + if np.isclose(float(left_values), float(right_values), rtol=rtol, atol=atol): + continue except (TypeError, ValueError): pass print(f"❌ Value difference at {key}: {left_values} vs {right_values}") @@ -141,10 +156,9 @@ def compare_checkpoints(left_path: str, right_path: str, rtol: float = 1e-3, ato parser = argparse.ArgumentParser() parser.add_argument("--lhs", type=str, required=True) parser.add_argument("--rhs", type=str, required=True) + parser.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance for numerical comparison.") + parser.add_argument("--atol", type=float, default=1e-3, help="Absolute tolerance for numerical comparison.") args = parser.parse_args() - are_checkpoints_same = compare_checkpoints(args.lhs, args.rhs) + are_checkpoints_same = compare_checkpoints(args.lhs, args.rhs, rtol=args.rtol, atol=args.atol) print(f"\nComparison result: {are_checkpoints_same}") - - -