3 changes: 3 additions & 0 deletions README.md
@@ -34,6 +34,9 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or

See our installation guide to [install MaxText with pip](https://maxtext.readthedocs.io/en/latest/guides/install_maxtext.html).

## Decoupled mode
See the [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html) for running MaxText in decoupled mode, without any GCP dependencies.

## 🔥 Latest news 🔥

* \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
43 changes: 43 additions & 0 deletions dependencies/requirements/requirements_decoupled_jax_0_7.1.txt
@@ -0,0 +1,43 @@
absl_py>=2.3.1
aqtp>=0.9.0
chex>=0.1.90
datasets>=4.2.0
etils>=1.13.0
evaluate>=0.4.6
flax
grain>=0.2.12
grpcio>=1.75.1
huggingface_hub>=0.35.3
jaxtyping>=0.3.3
jsonlines>=4.0.0
matplotlib>=3.10.3
ml_collections>=1.1.0
ml_dtypes>=0.5.3
nltk>=3.9.2
numpy>=2.0.2
omegaconf>=2.3.0
optax>=0.2.6
orbax-checkpoint>=0.11.25
pandas>=2.3.3
pathwaysutils>=0.1.3
pillow>=11.3.0
protobuf>=5.29.5
psutil>=7.0.0
pytest>=8.4.1
PyYAML>=6.0.3
Requests>=2.32.5
qwix>=0.1.1
safetensors>=0.6.2
sentencepiece>=0.2.1
setuptools>=80.9.0
tabulate>=0.9.0
tensorflow>=2.19.1
tensorflow_text>=2.19.0
tensorflow_datasets>=4.9.9
tensorstore>=0.1.76
tiktoken>=0.12.0
tqdm>=4.67.1
transformers>=4.57.0
urllib3>=2.5.0
jax==0.7.1
git+https://github.com/google/tunix.git
1 change: 1 addition & 0 deletions docs/guides/run_maxtext.md
@@ -7,4 +7,5 @@ run_maxtext/run_maxtext_localhost.md
run_maxtext/run_maxtext_single_host_gpu.md
run_maxtext/run_maxtext_via_xpk.md
run_maxtext/run_maxtext_via_pathways.md
run_maxtext/decoupled_mode.md
```
86 changes: 86 additions & 0 deletions docs/guides/run_maxtext/decoupled_mode.md
@@ -0,0 +1,86 @@
<!--
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->


# Decoupled Mode (No Google Cloud Dependencies)

Set `DECOUPLE_GCLOUD=TRUE` to run MaxText tests and local development without any Google Cloud SDK, `gs://` buckets, JetStream, or Vertex AI integrations.

When enabled:
* Skips external integration tests via markers:
  * `external_serving` (`jetstream`, `serving`, `decode_server`)
  * `external_training` (`goodput`)
* A `decoupled` marker is applied by `tests/conftest.py` to tests that remain runnable in decoupled mode (i.e. not skipped for TPU or external markers).
* Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** in decoupled mode. This prevents accidentally running partial serving logic locally.
* Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
* Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
* Uses a local base output directory (override with `LOCAL_BASE_OUTPUT`).
* All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/test_utils.py`, which selects `decoupled_base_test.yml` in decoupled mode (a minimal sketch follows below).
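
A minimal sketch of what this helper might look like, assuming decoupled mode is detected via the environment variable (the real implementation lives in `tests/test_utils.py`, and the config paths below are illustrative):

```python
# Illustrative sketch only; the actual helper lives in tests/test_utils.py
# and the config paths below are assumptions.
import os


def get_test_config_path() -> str:
  """Return the base config for tests: GCS-free in decoupled mode."""
  if os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE":
    return "tests/configs/decoupled_base_test.yml"  # assumed location
  return "configs/base.yml"
```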

Minimal datasets included (checked into the repo):
* ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
located in `local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-{train,validation}.array_record-*`
* Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
located in `local_datasets/c4_en_dataset_minimal/hf/c4`
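
Both datasets can be regenerated locally with the generator scripts listed above:

```bash
# Regenerate the checked-in minimal datasets.
python local_datasets/get_minimal_c4_en_dataset.py   # ArrayRecord shards
python local_datasets/get_minimal_hf_c4_parquet.py   # HF-style Parquet files
```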


Run a local smoke test fully offline:
```bash
export DECOUPLE_GCLOUD=TRUE
pytest -k train_gpu_smoke_test -q
```

Optional environment variables:
* `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
* `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.
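
For example, a fully offline test run might override both (the values below are illustrative):

```bash
export DECOUPLE_GCLOUD=TRUE
export LOCAL_GCLOUD_PROJECT=my-local-project   # placeholder; no GCP calls are made
export LOCAL_BASE_OUTPUT=/tmp/maxtext_output   # test outputs land here
pytest -m decoupled -q tests
```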

## Centralized Decoupling API (`gcloud_stub.py`)

MaxText exposes a single module `MaxText.gcloud_stub` to avoid scattering environment checks:

```python
from MaxText.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream

if is_decoupled():
# Skip optional integrations or use local fallbacks
pass

# Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration)
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = cloud_diagnostics()

# JetStream (serving) components
config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()
TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object)
```

Behavior when `DECOUPLE_GCLOUD=TRUE`:
* `is_decoupled()` returns True.
* Each helper returns lightweight stubs whose attributes are safe to access; a clear `RuntimeError` is raised only when a stub is actually called (see the sketch below).
* Prevents import-time failures for optional dependencies (JetStream).
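
The attribute-safe, call-raising behavior could be implemented with a pattern like this (an illustrative sketch, not the actual `gcloud_stub.py` code):

```python
# Sketch of the attribute-safe, call-raising stub pattern described above.
# This is an illustration, not the actual MaxText.gcloud_stub implementation.


class _DecoupledStub:
  """Stand-in for an optional cloud dependency in decoupled mode.

  Attribute access always succeeds and returns another stub, so modules
  that reference optional components still import cleanly; only *calling*
  a stub raises.
  """

  def __init__(self, name: str):
    self._name = name

  def __getattr__(self, attr: str) -> "_DecoupledStub":
    return _DecoupledStub(f"{self._name}.{attr}")

  def __call__(self, *args, **kwargs):
    raise RuntimeError(
        f"'{self._name}' requires an optional dependency that is disabled "
        "because DECOUPLE_GCLOUD=TRUE."
    )


config_lib = _DecoupledStub("jetstream.config_lib")
server_config = config_lib.ServerConfig  # safe: returns another stub
# server_config()                        # would raise RuntimeError
```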

## Guidelines
* Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
* Use `is_decoupled()` instead of checking `os.environ["DECOUPLE_GCLOUD"]` directly.
* Use `get_test_config_path()` instead of hard-coding `base.yml`.
* Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
* Add the appropriate external dependency marker (`external_serving` or `external_training`) to new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only part of a file needs an external dependency.
* Tests receive the `decoupled` marker when `DECOUPLE_GCLOUD` is set and they carry no external dependency markers (a sketch of this conftest logic follows below). Run them with:
```bash
pytest -m decoupled -vv tests
```
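
A conftest hook along these lines could apply the marker (a sketch; the actual logic lives in `tests/conftest.py`, and the TPU marker name is an assumption):

```python
# Illustrative sketch of marker application; the real logic lives in
# tests/conftest.py, and the TPU marker name below is an assumption.
import os

import pytest

EXTERNAL_MARKERS = {"external_serving", "external_training"}


def pytest_collection_modifyitems(config, items):
  if os.environ.get("DECOUPLE_GCLOUD", "").upper() != "TRUE":
    return
  for item in items:
    markers = {m.name for m in item.iter_markers()}
    if markers & EXTERNAL_MARKERS:
      item.add_marker(pytest.mark.skip(reason="external dependency unavailable in decoupled mode"))
    elif "tpu_only" not in markers:  # assumed name for the TPU-only marker
      item.add_marker(pytest.mark.decoupled)
```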

This centralized approach keeps optional integrations cleanly separated from core MaxText logic, making local development (e.g. on ROCm/NVIDIA GPUs) frictionless.

Binary files not shown (20 files).
@@ -0,0 +1,38 @@
{
"configDescription": "Local minimal C4 EN subset",
"configName": "en",
"description": "Local minimal C4 English subset.",
"fileFormat": "tfrecord",
"location": {
"urls": [
"https://www.tensorflow.org/datasets/catalog/c4"
]
},
"moduleName": "__main__",
"name": "__local_c4_builder",
"splits": [
{
"filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
"name": "train",
"shardLengths": [
"125",
"125",
"125",
"125",
"125",
"125",
"125",
"125"
]
},
{
"filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
"name": "validation",
"shardLengths": [
"100",
"100"
]
}
],
"version": "3.0.1"
}
11 changes: 11 additions & 0 deletions local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/features.json
@@ -0,0 +1,11 @@
{
"featuresDict": {
"features": {
"text": {
"pythonClassName": "tensorflow_datasets.core.features.text_feature.Text",
"text": {}
}
}
},
"pythonClassName": "tensorflow_datasets.core.features.features_dict.FeaturesDict"
}
Binary files not shown (20 files).
@@ -0,0 +1,38 @@
{
"configDescription": "Local minimal C4 EN subset",
"configName": "en",
"description": "Local minimal C4 English subset.",
"fileFormat": "tfrecord",
"location": {
"urls": [
"https://www.tensorflow.org/datasets/catalog/c4"
]
},
"moduleName": "__main__",
"name": "__local_c4_builder",
"splits": [
{
"filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
"name": "train",
"shardLengths": [
"125",
"125",
"125",
"125",
"125",
"125",
"125",
"125"
]
},
{
"filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
"name": "validation",
"shardLengths": [
"100",
"100"
]
}
],
"version": "3.1.0"
}
11 changes: 11 additions & 0 deletions local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/features.json
@@ -0,0 +1,11 @@
{
"featuresDict": {
"features": {
"text": {
"pythonClassName": "tensorflow_datasets.core.features.text_feature.Text",
"text": {}
}
}
},
"pythonClassName": "tensorflow_datasets.core.features.features_dict.FeaturesDict"
}
Binary files not shown (2 files).
131 changes: 131 additions & 0 deletions local_datasets/convert_arrayrecord_to_tfrecord.py
@@ -0,0 +1,131 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert minimal C4 ArrayRecord shards to TFRecord shards.

This script scans a version directory that contains ArrayRecord shards
and produces TFRecord files with TFDS-compatible shard names.

Example usage:
python local_datasets/convert_arrayrecord_to_tfrecord.py \
--version-dir local_datasets/c4_en_dataset_minimal/c4/en/3.0.1 \
--builder-name __local_c4_builder \
--force

Options:
--dry-run Only show planned conversions.
--force Overwrite existing TFRecord output files.

Dependencies:
array_record (Python bindings)
tensorflow

Limitations:
Records are copied verbatim; compression is not applied.
"""
from __future__ import annotations
import os
import argparse
import glob
import sys
from typing import List

try:
from array_record.python.array_record_module import ArrayRecordReader
except ModuleNotFoundError:
print("Error: array_record module not found. Install appropriate package before running.")
sys.exit(1)

import tensorflow as tf


def discover_shards(version_dir: str, split: str) -> List[str]:
"""Return sorted list of ArrayRecord shard paths for a split."""
pattern = os.path.join(version_dir, f"c4-{split}.array_record-*")
return sorted(glob.glob(pattern))


def parse_shard_numbers(fname: str) -> tuple[str, str]:
"""Extract shard index and total from a shard filename.

Example: c4-train.array_record-00003-of-00008 -> ("00003", "00008").
"""
base = os.path.basename(fname)
parts = base.split("-")
# the filename ends with <index>-of-<total>: parts[-3] is the shard index,
# parts[-2] is the literal "of", and parts[-1] is the shard total
shard_idx = parts[-3]
total = parts[-1]
return shard_idx, total


def convert_shard(arrayrecord_path: str, output_path: str, force: bool) -> None:
"""Convert a single ArrayRecord shard into a TFRecord file.

If the output exists and ``force`` is False, the function skips conversion.
"""
if os.path.exists(output_path) and not force:
print(f"Skip existing: {output_path}")
return

reader = ArrayRecordReader(arrayrecord_path)
count = reader.num_records()
written = 0
batch_size = 1024

with tf.io.TFRecordWriter(output_path) as writer:
start = 0
while start < count:
end = min(start + batch_size, count)
# reader.read(start, end) returns list of records in [start,end)
batch = reader.read(start, end)
for rec in batch:
writer.write(rec)
written += 1
start = end

print(f"Converted {arrayrecord_path} -> {output_path} ({written} / {count} records)")


def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--version-dir", required=True, help="Directory like c4_en_dataset_minimal/c4/en/3.0.1")
ap.add_argument("--builder-name", default="__local_c4_builder", help="Prefix used for TFRecord shard filenames.")
ap.add_argument("--dry-run", action="store_true", help="Only list planned conversions.")
ap.add_argument("--force", action="store_true", help="Overwrite existing TFRecord shards if present.")
args = ap.parse_args()

if not os.path.isdir(args.version_dir):
print(f"Version directory not found: {args.version_dir}")
sys.exit(1)

for split in ["train", "validation"]:
shards = discover_shards(args.version_dir, split)
if not shards:
print(f"No ArrayRecord shards found for split '{split}' in {args.version_dir}")
continue
print(f"Found {len(shards)} {split} ArrayRecord shards.")
for shard in shards:
shard_idx, total = parse_shard_numbers(shard)
tfrec_name = f"{args.builder_name}-{split}.tfrecord-{shard_idx}-of-{total}"
out_path = os.path.join(args.version_dir, tfrec_name)
if args.dry_run:
print(f"Would create: {out_path} from {shard}")
else:
convert_shard(shard, out_path, args.force)

print("Done.")


if __name__ == "__main__":
main()