Commit acd9e07
Merge pull request #2722 from AI-Hypercomputer:pydantic_fix
PiperOrigin-RevId: 834913985
2 parents c2a68d0 + ccb4995 commit acd9e07

File tree: 8 files changed (+21, -20 lines)

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ RUN pip install numba==0.61.2
 # Install vLLM for Jax and TPUs
 RUN pip install vllm-tpu
 
+RUN pip install --no-deps qwix==0.1.4
+
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     pip uninstall -y jax jaxlib libtpu && \
     pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@ RUN pip install keyring keyrings.google-artifactregistry-auth
 RUN pip install numba==0.61.2
 
 COPY tunix /tunix
+RUN pip uninstall -y google-tunix
 RUN pip install -e /tunix --no-cache-dir
 
 
@@ -49,6 +50,7 @@ RUN pip install -e /tpu-inference --no-cache-dir --pre \
     --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
     --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
 
+RUN pip install --no-deps qwix==0.1.4
 
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \

src/MaxText/configs/rl.yml

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ train_fraction: 1.0
 
 eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 
-num_epochs: 1 # can potentially train for more epochs
+num_epoch: 1 # can potentially train for more epochs
 
 learning_rate: 3e-6
 adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
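The YAML key is renamed, presumably to match the field declared on the typed pydantic config: a pydantic model only populates keys that match its declared fields, so a mismatched key silently leaves the field at its default (or is rejected outright under extra="forbid"). A minimal sketch of that behavior, using a hypothetical stand-in class rather than the real MaxText config:

# Sketch only: TrainerConfig stands in for the real pydantic config class.
from pydantic import BaseModel

class TrainerConfig(BaseModel):
  num_epoch: int = 1

print(TrainerConfig(num_epoch=2).num_epoch)   # 2
# With the old key the value never reaches the field: the unknown key is
# ignored by default (or raises if the model sets extra="forbid").
print(TrainerConfig(num_epochs=2).num_epoch)  # still 1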

src/MaxText/configs/types.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 from enum import Enum
 from math import prod
 from tempfile import gettempdir
-from typing import Any, NewType, Literal
+from typing import Any, NewType, Literal, Optional
 import datetime
 import logging
 import math
@@ -837,9 +837,9 @@ class HfDataset(BaseModel):
 
   hf_path: str = Field("", description="Path or name of the Hugging Face dataset.")
   hf_data_dir: PathStr = Field("", description="Data directory for the HF dataset.")
-  hf_train_files: str = Field("", description="Files for the HF training split.")
+  hf_train_files: Optional[str] = Field(None, description="Files for the HF training split.")
   hf_eval_split: str = Field("", description="Name of the HF evaluation split.")
-  hf_eval_files: str = Field("", description="Files for the HF evaluation split.")
+  hf_eval_files: Optional[str] = Field(None, description="Files for the HF evaluation split.")
   hf_access_token: None | str = Field(None, description="Hugging Face API access token.")
 
 
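For context, the reason these fields become Optional: a plain str field rejects None during validation, so "no files specified" previously had to be encoded as an empty string. A minimal sketch of the two behaviors, assuming pydantic v2 and hypothetical stand-in classes rather than the real HfDataset:

from typing import Optional
from pydantic import BaseModel, Field, ValidationError

class StrictFiles(BaseModel):    # old shape: str with "" default
  hf_train_files: str = Field("")

class OptionalFiles(BaseModel):  # new shape: Optional[str] with None default
  hf_train_files: Optional[str] = Field(None)

try:
  StrictFiles(hf_train_files=None)
except ValidationError as err:
  print(err.errors()[0]["type"])  # string_type: None is rejected by the str field

print(OptionalFiles().hf_train_files)                         # None (unset)
print(OptionalFiles(hf_train_files="*.json").hf_train_files)  # *.json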

src/MaxText/examples/install_tunix_vllm_requirement.sh

Lines changed: 2 additions & 4 deletions
@@ -19,7 +19,7 @@
 set -e
 set -x
 
-uv pip uninstall -y jax jaxlib libtpu
+uv pip uninstall jax jaxlib libtpu
 
 uv pip install aiohttp==3.12.15
 
@@ -28,6 +28,4 @@ uv pip install vllm-tpu
 
 uv pip install numba==0.61.2
 
-uv pip install qwix==0.1.1
-
-uv pip install flax==0.11.1
+uv pip install --no-deps qwix==0.1.4
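Both Dockerfiles and this script now pin qwix==0.1.4 and install it with --no-deps, which tells pip/uv not to pull in qwix's own dependencies, presumably so the JAX/Flax versions installed earlier are left untouched. A small post-install sanity-check sketch (not part of the build) that could be run inside the resulting environment:

# Hypothetical check; importlib.metadata is in the Python standard library.
from importlib.metadata import version

assert version("qwix") == "0.1.4", f"unexpected qwix version: {version('qwix')}"
print("qwix", version("qwix"))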

src/MaxText/model_creation_utils.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
   return model
 
 
-def create_nnx_model(config, mesh=None, devices=None, model_mode=None, rng_key=None):
+def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
   """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
 
   def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
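The signature change swaps a None default for the training-mode constant, so callers that omit model_mode (as train_rl.py does below) get the training mode instead of None. A trivial sketch of the difference, with stand-in definitions rather than the real implementation:

MODEL_MODE_TRAIN = "train"  # stand-in for the real constant

def old_create_nnx_model(config, mesh=None, devices=None, model_mode=None, rng_key=None):
  return model_mode  # callers that omit model_mode used to get None

def new_create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
  return model_mode  # omitted model_mode now falls back to the training mode

print(old_create_nnx_model(config=None))  # None
print(new_create_nnx_model(config=None))  # train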

src/MaxText/pyconfig.py

Lines changed: 5 additions & 6 deletions
@@ -97,11 +97,6 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
   pydantic_kwargs = {}
   valid_fields = types.MaxTextConfig.model_fields.keys()
 
-  # This is a workaround for tests that use `dataset_type='hf'` but do not
-  # specify `tokenizer_type='huggingface'`, which they should.
-  if raw_keys.get("dataset_type") == "hf" and "tokenizer_type" not in raw_keys:
-    raw_keys["tokenizer_type"] = "huggingface"
-
   for key, value in raw_keys.items():
     if key not in valid_fields:
       logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key))
@@ -119,7 +114,11 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
     if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
       new_value = [new_value]
 
-    if key in ("run_name", "hf_train_files", "hf_eval_files") and new_value is None:
+    # An empty value provided in the configuration is treated as None
+    if key in ("hf_train_files", "hf_eval_files") and new_value == "":
+      new_value = None
+
+    if key == "run_name" and new_value is None:
       new_value = ""
 
     pydantic_kwargs[key] = new_value
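A standalone sketch of the normalization added above (a hypothetical helper, not the real _prepare_for_pydantic): empty-string hf_train_files/hf_eval_files values coming from YAML or the CLI map to None so they satisfy the new Optional[str] fields, while a missing run_name still falls back to an empty string:

def normalize_value(key, new_value):
  # Empty file-pattern strings from YAML/CLI mean "not set" for the Optional fields.
  if key in ("hf_train_files", "hf_eval_files") and new_value == "":
    return None
  # run_name keeps its previous behavior: None becomes an empty string.
  if key == "run_name" and new_value is None:
    return ""
  return new_value

print(normalize_value("hf_train_files", ""))            # None
print(normalize_value("run_name", None))                # ''
print(normalize_value("hf_eval_files", "part-*.json"))  # part-*.json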

src/MaxText/rl/train_rl.py

Lines changed: 5 additions & 5 deletions
@@ -100,7 +100,7 @@ def get_maxtext_model(config, devices=None):
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices)
+  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
   with mesh:
     tunix_model = TunixMaxTextAdapter(base_model=model)
     tunix_model.config = None
@@ -238,7 +238,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       trainer_config.num_batches
       * trainer_config.num_iterations
       * trainer_config.train_fraction
-      * trainer_config.num_epochs
+      * trainer_config.num_epoch
   )
 
   # ====== Data ======
@@ -260,10 +260,10 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   )[: trainer_config.num_batches]
 
   if trainer_config.train_fraction == 1.0:
-    train_dataset = dataset.repeat(trainer_config.num_epochs)
+    train_dataset = dataset.repeat(trainer_config.num_epoch)
   else:
     train_dataset = dataset[: int(len(dataset) * trainer_config.train_fraction)]
-    train_dataset = train_dataset.repeat(trainer_config.num_epochs)
+    train_dataset = train_dataset.repeat(trainer_config.num_epoch)
 
   test_dataset = get_dataset(model_tokenizer, trainer_config, test_data_dir, trainer_config.eval_split).batch(
       trainer_config.batch_size
@@ -416,7 +416,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
           lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
           lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
       ],
-      grpo_config=grpo_config,
+      algo_config=grpo_config,
   )
 
   # Before we train the model, let's evaluate the model on the test set so we can
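The first hunk's switch from a positional to a keyword argument matters because create_nnx_model's second positional parameter is mesh, not devices, so the old call silently bound the device list to mesh. A minimal sketch of the mismatch, using stand-in code with the same parameter order as the signature shown above:

# Stand-in only; mirrors create_nnx_model(config, mesh=None, devices=None, ...).
def create_nnx_model(config, mesh=None, devices=None, model_mode="train", rng_key=None):
  return {"mesh": mesh, "devices": devices}

devices = ["tpu:0", "tpu:1"]  # hypothetical device list
print(create_nnx_model("cfg", devices))          # old call: devices lands in mesh
print(create_nnx_model("cfg", devices=devices))  # fixed call: devices stays devices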
