
Commit ed517cf

Merge pull request #2681 from AI-Hypercomputer:mohit/grpo_doc
PiperOrigin-RevId: 836820005
2 parents 6e9eb9d + 2b28c74 commit ed517cf

File tree: 5 files changed, +185 −55 lines changed


docs/tutorials/grpo.md

Lines changed: 52 additions & 6 deletions
@@ -46,18 +46,64 @@ Primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-projec
 
 You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source)
 
+## Set up environment variables
+
+Set up the following environment variables before running GRPO:
+
+```bash
+# -- Model configuration --
+export HF_MODEL='llama3.1-8b-Instruct'
+export MODEL='llama3.1-8b'
+export TOKENIZER='meta-llama/Llama-3.1-8B-Instruct'
+export HF_TOKEN=<Hugging Face access token>
+
+# -- MaxText configuration --
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+
+export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
+```
+
+## Get your model checkpoint
+
+You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
+
+First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
+
+```bash
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+  model_name=${HF_MODEL} \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${MAXTEXT_CKPT_PATH} \
+  scan_layers=True hardware=cpu skip_jax_distributed_system=true
+
+# Example: converting Llama3.1-70B with --lazy_load_tensors=true, which uses around 86 GB of RAM
+
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
+  model_name=llama3.1-70b \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME} \
+  scan_layers=True \
+  hardware=cpu skip_jax_distributed_system=true \
+  --lazy_load_tensors=true
+```
+
+
+
 ## Run GRPO
 
 Finally, run the command
 
 ```
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  model_name=llama3.1-8b \
-  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-  load_parameters_path=gs://path/to/checkpoint/0/items \
-  run_name=$WORKLOAD \
-  base_output_directory=$OUTPUT_PATH \
-  hf_access_token=$HF_TOKEN
+  model_name=${MODEL} \
+  tokenizer_path=${TOKENIZER} \
+  load_parameters_path=${MAXTEXT_CKPT_PATH} \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  hf_access_token=${HF_TOKEN}
 ```
 
 The overview of what this run will do is as follows:
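
Before launching the GRPO run, it can be worth confirming that the conversion step actually wrote a checkpoint. A minimal sanity check, assuming `gsutil` is installed and `BASE_OUTPUT_DIRECTORY` points at a `gs://` bucket (the exact subdirectory layout, e.g. the `0/items` step directory, is whatever the converter produced):

```bash
# Optional check: list what the conversion step wrote before pointing
# load_parameters_path at it. Assumes gsutil and a gs:// output directory.
gsutil ls -r "${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}" | head -n 20
```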

docs/tutorials/grpo_with_pathways.md

Lines changed: 55 additions & 9 deletions
@@ -29,6 +29,50 @@ Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hyp
 Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but
 we recommend creating the virtual environment outside the `maxtext` directory.
 
+
+## Set up environment variables
+
+Set up the following environment variables before running GRPO:
+
+```bash
+# -- Model configuration --
+export HF_MODEL='llama3.1-70b-Instruct'
+export MODEL='llama3.1-70b'
+export TOKENIZER='meta-llama/Llama-3.1-70B-Instruct'
+export HF_TOKEN=<Hugging Face access token>
+
+# -- MaxText configuration --
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+export RUN_NAME=llama-3-70b-grpo
+export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
+
+# -- Workload configuration --
+export WORKLOAD=${RUN_NAME}
+export TPU_TYPE='v5p-128'
+export TPU_CLUSTER=<cluster name>
+export PROJECT_ID=<GCP project ID>
+export ZONE=<zone name>
+```
+
+## Get your model checkpoint
+
+You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
+
+First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. \
+For example, converting a Llama3.1-70B scanned checkpoint using `--lazy_load_tensors=true` uses around 200 GB of RAM and completes in about 10 minutes. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
+
+```bash
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# Using --lazy_load_tensors=true here reduces memory usage, e.g. Llama3.1-70B conversion takes around 86 GB of RAM
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
+  model_name=${HF_MODEL} \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME} \
+  scan_layers=true checkpoint_storage_use_ocdbt=false checkpoint_storage_use_zarr3=false \
+  skip_jax_distributed_system=true --lazy_load_tensors=true
+```
+
 ## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies
 
 ### Installing stable releases of tunix and vllm-tpu
@@ -45,28 +89,30 @@ You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MOD
 ### Install from locally git cloned repos
 
 You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm.git) and then use the following command to build a docker image using them:
-`bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local`
+```
+bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
+```
 
 ### Upload the dependency docker image along with MaxText code
 ```
-bash docker_upload_runner.sh CLOUD_IMAGE_NAME=path/to/gcr.io
+bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
 ```
 
 ### Submit your jobs
 
 Please create a Pathways-ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk)
 ```
 xpk workload create-pathways --workload $WORKLOAD \
-  --docker-image path/to/gcr.io:latest --cluster $TPU_CLUSTER \
+  --docker-image <path/to/gcr.io> --cluster $TPU_CLUSTER \
   --tpu-type=$TPU_TYPE --num-slices=1 --zone=$ZONE \
   --project=$PROJECT_ID --priority=high \
-  --command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
+  --command "TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
   python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  model_name=llama3.1-70b \
-  tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
-  load_parameters_path=gs://path/to/checkpoint/0/items \
-  run_name=$WORKLOAD \
-  base_output_directory=$OUTPUT_PATH \
+  model_name=${MODEL} \
+  tokenizer_path=${TOKENIZER} \
+  load_parameters_path=${MAXTEXT_CKPT_PATH} \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
   hf_access_token=$HF_TOKEN"
 ```
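
Once the workload is submitted, it can be monitored and torn down with the usual XPK subcommands. A brief sketch, assuming a recent xpk version that provides `workload list` and `workload delete` and the same environment variables defined above:

```bash
# Check the status of the submitted workload.
xpk workload list --cluster ${TPU_CLUSTER} --project ${PROJECT_ID} --zone ${ZONE}

# Delete the workload once the run has finished or if it needs to be resubmitted.
xpk workload delete --workload ${WORKLOAD} --cluster ${TPU_CLUSTER} \
  --project ${PROJECT_ID} --zone ${ZONE}
```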

src/MaxText/pyconfig.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 import os
 import sys
 from typing import Any
+import copy
 
 import jax
 import jax.numpy as jnp
@@ -151,6 +152,13 @@ def __init__(self, pydantic_config: types.MaxTextConfig):
 
     object.__setattr__(self, "_flat_config", final_dict)
 
+  def __deepcopy__(self, memo):
+    new_pydantic_config = copy.deepcopy(self._pydantic_config, memo)
+    return HyperParameters(new_pydantic_config)
+
+  def tree_flatten(self):
+    return (), self
+
   def __getattr__(self, attr: str) -> Any:
     """Provides attribute-style access to the final configuration dictionary."""
     if attr in self._flat_config:
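
For context on the two additions: `__deepcopy__` rebuilds the wrapper from a deep copy of the underlying pydantic config, and a `tree_flatten` that returns `((), self)` matches the shape JAX's pytree protocol expects when an object should be carried as static auxiliary data with no array leaves. The sketch below illustrates that convention with a toy config class; it shows the pattern only and is not MaxText's actual registration code.

```python
import copy
import dataclasses

import jax


# Toy stand-in for a config object whose tree_flatten() returns ((), self):
# no array leaves, and the object itself rides along as static aux_data.
@dataclasses.dataclass(frozen=True)
class TinyConfig:
  learning_rate: float = 3e-4

  def tree_flatten(self):
    return (), self


jax.tree_util.register_pytree_node(
    TinyConfig,
    lambda cfg: cfg.tree_flatten(),       # children = (), aux_data = the config
    lambda aux_data, children: aux_data,  # unflatten: just return the stored config
)

cfg = TinyConfig()
assert jax.tree_util.tree_leaves(cfg) == []  # the config contributes no traced leaves
cfg_copy = copy.deepcopy(cfg)                # mirrors the new __deepcopy__ behaviour
```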

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 35 additions & 7 deletions
@@ -279,19 +279,34 @@ def __repr__(self):
 
 
 class LazyTensorHandler(type_handlers.NumpyHandler):
-  """Custom Orbax handler for LazyTensor to avoid typestr collision with np.ndarray."""
+  """
+  Custom Orbax handler for LazyTensor.
+
+  It masquerades as a standard NumpyHandler so that the resulting checkpoint
+  has the standard 'array_metadatas' structure and can be loaded by
+  standard MaxText instances.
+  """
 
-  def typestr(self):
-    return "LazyTensor"
+  async def serialize(self, value, *args, **kwargs):
+    # MATERIALIZE: Trigger the lazy load (__array__) explicitly before saving.
+    # This ensures the parent NumpyHandler receives a real np.ndarray.
+    if hasattr(value, "__array__"):
+      value = np.array(value)
+
+    return await super().serialize(value, *args, **kwargs)
 
 
 # Register LazyTensor with the custom handler.
 # It's safe to register this globally even if eager loading is used.
-type_handlers.register_type_handler(LazyTensor, LazyTensorHandler())
+type_handlers.register_type_handler(LazyTensor, LazyTensorHandler(), override=True)
 
 
 def _build_multi_axis_stacked_tensor(
-    hf_source_keys: List[List[str]], tensor_getter_fn: Callable[[str], np.ndarray], hook_fns: Any
+    hf_source_keys: List[List[str]],
+    tensor_getter_fn: Callable[[str], np.ndarray],
+    hook_fns: Any,
+    target_shape: tuple,
+    config,
 ) -> np.ndarray:
   """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).
 
@@ -303,18 +318,24 @@ def _build_multi_axis_stacked_tensor(
       Outer list iterates experts, inner list iterates layers.
     tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
     hook_fns: The hook function(s) to apply to each individual weight.
+    target_shape: The final shape of the target MaxText tensor.
+    config: The MaxText pyconfig object.
 
   Returns:
     The final, assembled NumPy array for the MaxText parameter.
   """
   all_expert_tensors = []
+  # The hook function needs the shape of an individual slice, not the full stacked tensor.
+  # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:]
+  mt_slice_shape = target_shape[2:]
+
   # Outer loop iterates through experts
   for layer_keys_for_expert in hf_source_keys:
     layer_tensors_for_expert = []
     # Inner loop iterates through layers for the current expert
     for hf_key_single in layer_keys_for_expert:
      hf_tensor_numpy = tensor_getter_fn(hf_key_single)
-      processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, None, hook_fns)
+      processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
       layer_tensors_for_expert.append(processed_hf_tensor)
     all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
   return np.stack(all_expert_tensors, axis=0)
@@ -514,7 +535,14 @@ def _loader(getter, key, shape, hook):
   # Stacked mapping
   if isinstance(hf_source_keys_or_key[0], list):
     # Case 2: Multi-Axis Stacked
-    load_fn = partial(_build_multi_axis_stacked_tensor, hf_source_keys_or_key, tensor_getter, hook_fn)
+    load_fn = partial(
+        _build_multi_axis_stacked_tensor,
+        hf_source_keys_or_key,
+        tensor_getter,
+        hook_fn,
+        mt_target_shape_final,
+        config,
+    )
   else:
     # Case 3: Single-Axis Stacked
     load_fn = partial(
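
For readers unfamiliar with the contract that `LazyTensorHandler.serialize` relies on: an object implementing numpy's `__array__` protocol stays unmaterialized until `np.array(value)` is called on it. The toy class below is an illustration of that deferred-load behaviour only, not the MaxText `LazyTensor` class.

```python
import numpy as np


class TinyLazyTensor:
  """Toy lazy wrapper: defers loading until something asks for a real ndarray."""

  def __init__(self, loader):
    self._loader = loader  # zero-argument callable that actually reads the weight

  def __array__(self, dtype=None, copy=None):
    data = np.asarray(self._loader())  # materialization happens here
    return data.astype(dtype) if dtype is not None else data


lazy = TinyLazyTensor(lambda: np.ones((2, 2), dtype=np.float32))
dense = np.array(lazy)  # triggers __array__, just as the serialize() override does
print(dense.shape)      # (2, 2)
```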

tests/train_using_ragged_dot_smoke_test.py

Lines changed: 35 additions & 33 deletions
@@ -37,39 +37,41 @@ class Train(parameterized.TestCase):
   def test_tiny_config(self, quantization: str):
     test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
     outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
-    train_main([
-        None,
-        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
-        f"base_output_directory={test_tmpdir}",
-        "run_name=ragged_dot_smoke_test",
-        "base_emb_dim=128",
-        "base_num_query_heads=4",
-        "base_num_kv_heads=4",
-        "base_mlp_dim=128",
-        "base_moe_mlp_dim=128",
-        "base_num_decoder_layers=8",
-        "head_dim=128",
-        # TODO(b/441100085): When changing the decoder_block we might
-        # need to adjust the tiling.
-        "decoder_block=deepseek",
-        "attention_type=mla",
-        "num_experts=2",
-        # Enable sparse_matmul.
-        "sparse_matmul=True",
-        # Enable ragged_dot.
-        "megablox=False",
-        f'quantization="{quantization}"',
-        "use_qwix_quantization=True",
-        "per_device_batch_size=2",
-        "max_target_length=1024",
-        "dataset_type=synthetic",
-        "steps=10",
-        "enable_checkpointing=False",
-        "enable_goodput_recording=False",
-        "enable_checkpoint_cloud_logger=False",
-        "monitor_goodput=False",
-        f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
-    ])
+    train_main(
+        [
+            None,
+            os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+            f"base_output_directory={test_tmpdir}",
+            "run_name=ragged_dot_smoke_test",
+            "base_emb_dim=128",
+            "base_num_query_heads=4",
+            "base_num_kv_heads=4",
+            "base_mlp_dim=128",
+            "base_moe_mlp_dim=128",
+            "base_num_decoder_layers=8",
+            "head_dim=128",
+            # TODO(b/441100085): When changing the decoder_block we might
+            # need to adjust the tiling.
+            "decoder_block=deepseek",
+            "attention_type=mla",
+            "num_experts=2",
+            # Enable sparse_matmul.
+            "sparse_matmul=True",
+            # Enable ragged_dot.
+            "megablox=False",
+            f'quantization="{quantization}"',
+            "use_qwix_quantization=True",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "dataset_type=synthetic",
+            "steps=10",
+            "enable_checkpointing=False",
+            "enable_goodput_recording=False",
+            "enable_checkpoint_cloud_logger=False",
+            "monitor_goodput=False",
+            f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
+        ]
+    )
 
 
 if __name__ == "__main__":
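
To exercise the reformatted smoke test locally, either invocation below should work, assuming MaxText and its test dependencies are installed; the test runs 10 steps on synthetic data with a tiny DeepSeek-style MoE config, so it is intended as a quick check.

```bash
# Run through pytest, which collects the parameterized quantization cases...
python3 -m pytest tests/train_using_ragged_dot_smoke_test.py -q

# ...or run the module directly via its own test main.
python3 tests/train_using_ragged_dot_smoke_test.py
```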
