diff --git a/README.md b/README.md
index a801421b6..0217cda10 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,9 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or
 
 See our installation guide to [install MaxText with pip](https://maxtext.readthedocs.io/en/latest/guides/install_maxtext.html).
 
+## Decoupled mode
+See the [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html) for running MaxText in decoupled mode, without any GCP dependencies.
+
 ## 🔥 Latest news 🔥
 
 * \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
diff --git a/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt b/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt
new file mode 100644
index 000000000..ec16b7ca6
--- /dev/null
+++ b/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt
@@ -0,0 +1,43 @@
+absl_py>=2.3.1
+aqtp>=0.9.0
+chex>=0.1.90
+datasets>=4.2.0
+etils>=1.13.0
+evaluate>=0.4.6
+flax
+grain>=0.2.12
+grpcio>=1.75.1
+huggingface_hub>=0.35.3
+jaxtyping>=0.3.3
+jsonlines>=4.0.0
+matplotlib>=3.10.3
+ml_collections>=1.1.0
+ml_dtypes>=0.5.3
+nltk>=3.9.2
+numpy>=2.0.2
+omegaconf>=2.3.0
+optax>=0.2.6
+orbax-checkpoint>=0.11.25
+pandas>=2.3.3
+pathwaysutils>=0.1.3
+pillow>=11.3.0
+protobuf>=5.29.5
+psutil>=7.0.0
+pytest>=8.4.1
+PyYAML>=6.0.3
+Requests>=2.32.5
+qwix>=0.1.1
+safetensors>=0.6.2
+sentencepiece>=0.2.1
+setuptools>=80.9.0
+tabulate>=0.9.0
+tensorflow>=2.19.1
+tensorflow_text>=2.19.0
+tensorflow_datasets>=4.9.9
+tensorstore>=0.1.76
+tiktoken>=0.12.0
+tqdm>=4.67.1
+transformers>=4.57.0
+urllib3>=2.5.0
+jax==0.7.1
+git+https://github.com/google/tunix.git
\ No newline at end of file
diff --git a/docs/guides/run_maxtext.md b/docs/guides/run_maxtext.md
index 6864d2bba..b7bb32519 100644
--- a/docs/guides/run_maxtext.md
+++ b/docs/guides/run_maxtext.md
@@ -7,4 +7,5 @@ run_maxtext/run_maxtext_localhost.md
 run_maxtext/run_maxtext_single_host_gpu.md
 run_maxtext/run_maxtext_via_xpk.md
 run_maxtext/run_maxtext_via_pathways.md
+run_maxtext/decoupled_mode.md
 ```
diff --git a/docs/guides/run_maxtext/decoupled_mode.md b/docs/guides/run_maxtext/decoupled_mode.md
new file mode 100644
index 000000000..930231b1b
--- /dev/null
+++ b/docs/guides/run_maxtext/decoupled_mode.md
@@ -0,0 +1,86 @@
+
+
+
+# Decoupled Mode (No Google Cloud Dependencies)
+
+Set `DECOUPLE_GCLOUD=TRUE` to run MaxText tests and local development without any Google Cloud SDK, `gs://` buckets, JetStream, or Vertex AI integrations.
+
+When enabled:
+* Skips external integration tests with markers:
+  * `external_serving` (`jetstream`, `serving`, `decode_server`)
+  * `external_training` (`goodput`)
+* Applies a `decoupled` marker (via `tests/conftest.py`) to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
+* Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally while decoupled mode is ON.
+* Preserves import-time safety through lightweight stubs (see `gcloud_stub.py` below), so modules import cleanly; only active use of missing functionality raises.
+* Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
+* Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
+* All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/test_utils.py`, which ensures `decoupled_base_test.yml` is used in decoupled mode (see the sketch at the end of this guide).
+
+Minimal datasets included (checked into the repo):
+* ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
+  located in `local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-{train,validation}.array_record-*`
+* Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
+  located in `local_datasets/c4_en_dataset_minimal/hf/c4`
+
+
+Run a local smoke test fully offline:
+```bash
+export DECOUPLE_GCLOUD=TRUE
+pytest -k train_gpu_smoke_test -q
+```
+
+Optional environment variables:
+* `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
+* `LOCAL_BASE_OUTPUT` - override the default local output directory used in tests.
+
+## Centralized Decoupling API (`gcloud_stub.py`)
+
+MaxText exposes a single module `MaxText.gcloud_stub` to avoid scattering environment checks:
+
+```python
+from MaxText.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream
+
+if is_decoupled():
+  # Skip optional integrations or use local fallbacks
+  pass
+
+# Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration)
+diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = cloud_diagnostics()
+
+# JetStream (serving) components
+config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()
+TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object)
+```
+
+Behavior when `DECOUPLE_GCLOUD=TRUE`:
+* `is_decoupled()` returns `True`.
+* Each helper returns lightweight stubs whose attributes are safe to access; calling their methods raises a clear `RuntimeError` only when actually invoked.
+* Prevents import-time failures for optional dependencies (e.g. JetStream).
+
+## Guidelines
+* Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
+* Use `is_decoupled()` instead of checking `os.environ["DECOUPLE_GCLOUD"]` directly.
+* Use `get_test_config_path()` instead of a hard-coded `base.yml` path.
+* Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
+* Add the appropriate external dependency marker (`external_serving` or `external_training`) to new tests. Prefer the smallest scope rather than module-wide `pytestmark` when only part of a file needs an external dependency.
+* The `decoupled` marker is added automatically when `DECOUPLE_GCLOUD=TRUE` and a test carries no external dependency marker. Run these tests with:
+```bash
+pytest -m decoupled -vv tests
+```
+
+This centralized approach keeps optional integrations cleanly separated from core MaxText logic, making local development (e.g. on ROCm/NVIDIA GPUs) frictionless.
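+
+The `get_test_config_path()` helper mentioned above is conceptually a small switch on `is_decoupled()`. The snippet below is an illustrative sketch only; the real helper lives in `tests/test_utils.py` and may resolve config paths differently:
+
+```python
+import os
+
+from MaxText.gcloud_stub import is_decoupled
+
+
+def get_test_config_path() -> str:
+  """Return the test config: decoupled_base_test.yml when decoupled, else base.yml."""
+  config_name = "decoupled_base_test.yml" if is_decoupled() else "base.yml"
+  # The path below assumes a source checkout layout; adjust for your environment.
+  return os.path.join("src", "MaxText", "configs", config_name)
+```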
+ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00000-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00000-of-00008 new file mode 100644 index 000000000..de257cc5e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00000-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00001-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00001-of-00008 new file mode 100644 index 000000000..24e90164b Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00001-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00002-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00002-of-00008 new file mode 100644 index 000000000..1756785e7 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00002-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00003-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00003-of-00008 new file mode 100644 index 000000000..caabe9c66 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00003-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00004-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00004-of-00008 new file mode 100644 index 000000000..3bbb0f3b4 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00004-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00005-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00005-of-00008 new file mode 100644 index 000000000..3b3e81a35 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00005-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00006-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00006-of-00008 new file mode 100644 index 000000000..1b8f4ad1d Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00006-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00007-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00007-of-00008 new file mode 100644 index 000000000..4ccc28606 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-train.tfrecord-00007-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00000-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00000-of-00002 new file mode 100644 index 000000000..3bcab079a Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00000-of-00002 differ 
diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00001-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00001-of-00002 new file mode 100644 index 000000000..81a8647bd Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/__local_c4_builder-validation.tfrecord-00001-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00000-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00000-of-00008 new file mode 100644 index 000000000..7bb8ab8df Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00000-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00001-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00001-of-00008 new file mode 100644 index 000000000..09cc1164e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00001-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00002-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00002-of-00008 new file mode 100644 index 000000000..131833b7d Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00002-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00003-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00003-of-00008 new file mode 100644 index 000000000..70e9cfa8e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00003-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00004-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00004-of-00008 new file mode 100644 index 000000000..8f981971a Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00004-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00005-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00005-of-00008 new file mode 100644 index 000000000..bf742e583 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00005-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00006-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00006-of-00008 new file mode 100644 index 000000000..8fa1565b5 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00006-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00007-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00007-of-00008 new file mode 100644 index 000000000..38160b896 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-train.array_record-00007-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00000-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00000-of-00002 new file mode 100644 index 000000000..24230734d Binary files /dev/null and 
b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00000-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00001-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00001-of-00002 new file mode 100644 index 000000000..c5ac72183 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-validation.array_record-00001-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/dataset_info.json b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/dataset_info.json new file mode 100644 index 000000000..8b7c1359f --- /dev/null +++ b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/dataset_info.json @@ -0,0 +1,38 @@ +{ + "configDescription": "Local minimal C4 EN subset", + "configName": "en", + "description": "Local minimal C4 English subset.", + "fileFormat": "tfrecord", + "location": { + "urls": [ + "https://www.tensorflow.org/datasets/catalog/c4" + ] + }, + "moduleName": "__main__", + "name": "__local_c4_builder", + "splits": [ + { + "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", + "name": "train", + "shardLengths": [ + "125", + "125", + "125", + "125", + "125", + "125", + "125", + "125" + ] + }, + { + "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", + "name": "validation", + "shardLengths": [ + "100", + "100" + ] + } + ], + "version": "3.0.1" +} \ No newline at end of file diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/features.json b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/features.json new file mode 100644 index 000000000..bf13dbfb1 --- /dev/null +++ b/local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/features.json @@ -0,0 +1,11 @@ +{ + "featuresDict": { + "features": { + "text": { + "pythonClassName": "tensorflow_datasets.core.features.text_feature.Text", + "text": {} + } + } + }, + "pythonClassName": "tensorflow_datasets.core.features.features_dict.FeaturesDict" +} \ No newline at end of file diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00000-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00000-of-00008 new file mode 100644 index 000000000..de257cc5e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00000-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00001-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00001-of-00008 new file mode 100644 index 000000000..24e90164b Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00001-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00002-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00002-of-00008 new file mode 100644 index 000000000..1756785e7 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00002-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00003-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00003-of-00008 new file mode 100644 index 000000000..caabe9c66 Binary files /dev/null and 
b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00003-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00004-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00004-of-00008 new file mode 100644 index 000000000..3bbb0f3b4 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00004-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00005-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00005-of-00008 new file mode 100644 index 000000000..3b3e81a35 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00005-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00006-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00006-of-00008 new file mode 100644 index 000000000..1b8f4ad1d Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00006-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00007-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00007-of-00008 new file mode 100644 index 000000000..4ccc28606 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-train.tfrecord-00007-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00000-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00000-of-00002 new file mode 100644 index 000000000..3bcab079a Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00000-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00001-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00001-of-00002 new file mode 100644 index 000000000..81a8647bd Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/__local_c4_builder-validation.tfrecord-00001-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00000-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00000-of-00008 new file mode 100644 index 000000000..7bb8ab8df Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00000-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00001-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00001-of-00008 new file mode 100644 index 000000000..09cc1164e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00001-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00002-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00002-of-00008 new file mode 100644 index 000000000..131833b7d Binary files /dev/null and 
b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00002-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00003-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00003-of-00008 new file mode 100644 index 000000000..70e9cfa8e Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00003-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00004-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00004-of-00008 new file mode 100644 index 000000000..8f981971a Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00004-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00005-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00005-of-00008 new file mode 100644 index 000000000..bf742e583 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00005-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00006-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00006-of-00008 new file mode 100644 index 000000000..8fa1565b5 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00006-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00007-of-00008 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00007-of-00008 new file mode 100644 index 000000000..38160b896 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-train.array_record-00007-of-00008 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00000-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00000-of-00002 new file mode 100644 index 000000000..24230734d Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00000-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00001-of-00002 b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00001-of-00002 new file mode 100644 index 000000000..c5ac72183 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/c4-validation.array_record-00001-of-00002 differ diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/dataset_info.json b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/dataset_info.json new file mode 100644 index 000000000..2e762de0f --- /dev/null +++ b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/dataset_info.json @@ -0,0 +1,38 @@ +{ + "configDescription": "Local minimal C4 EN subset", + "configName": "en", + "description": "Local minimal C4 English subset.", + "fileFormat": "tfrecord", + "location": { + "urls": [ + "https://www.tensorflow.org/datasets/catalog/c4" + ] + }, + "moduleName": "__main__", + "name": "__local_c4_builder", + "splits": [ + { + "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", + "name": "train", + "shardLengths": [ + "125", + "125", + "125", + "125", + "125", + "125", + "125", + "125" + ] + }, + { + "filepathTemplate": "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}", + "name": "validation", + 
"shardLengths": [ + "100", + "100" + ] + } + ], + "version": "3.1.0" +} \ No newline at end of file diff --git a/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/features.json b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/features.json new file mode 100644 index 000000000..bf13dbfb1 --- /dev/null +++ b/local_datasets/c4_en_dataset_minimal/c4/en/3.1.0/features.json @@ -0,0 +1,11 @@ +{ + "featuresDict": { + "features": { + "text": { + "pythonClassName": "tensorflow_datasets.core.features.text_feature.Text", + "text": {} + } + } + }, + "pythonClassName": "tensorflow_datasets.core.features.features_dict.FeaturesDict" +} \ No newline at end of file diff --git a/local_datasets/c4_en_dataset_minimal/hf/c4/c4-train-00000-of-01637.parquet b/local_datasets/c4_en_dataset_minimal/hf/c4/c4-train-00000-of-01637.parquet new file mode 100644 index 000000000..cac7de029 Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/hf/c4/c4-train-00000-of-01637.parquet differ diff --git a/local_datasets/c4_en_dataset_minimal/hf/c4/c4-validation-00000-of-01637.parquet b/local_datasets/c4_en_dataset_minimal/hf/c4/c4-validation-00000-of-01637.parquet new file mode 100644 index 000000000..32547edfe Binary files /dev/null and b/local_datasets/c4_en_dataset_minimal/hf/c4/c4-validation-00000-of-01637.parquet differ diff --git a/local_datasets/convert_arrayrecord_to_tfrecord.py b/local_datasets/convert_arrayrecord_to_tfrecord.py new file mode 100644 index 000000000..b7572f7ef --- /dev/null +++ b/local_datasets/convert_arrayrecord_to_tfrecord.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert minimal C4 ArrayRecord shards to TFRecord shards. + +This script scans a version directory that contains ArrayRecord shards +and produces TFRecord files with TFDS-compatible shard names. + +Example usage: + python local_datasets/convert_arrayrecord_to_tfrecord.py \ + --version-dir local_datasets/c4_en_dataset_minimal/c4/en/3.0.1 \ + --builder-name __local_c4_builder \ + --force + +Options: + --dry-run Only show planned conversions. + --force Overwrite existing TFRecord output files. + +Dependencies: + array_record (Python bindings) + tensorflow + +Limitations: + Records are copied verbatim; compression is not applied. +""" +from __future__ import annotations +import os +import argparse +import glob +import sys +from typing import List + +try: + from array_record.python.array_record_module import ArrayRecordReader +except ModuleNotFoundError: + print("Error: array_record module not found. Install appropriate package before running.") + sys.exit(1) + +import tensorflow as tf + + +def discover_shards(version_dir: str, split: str) -> List[str]: + """Return sorted list of ArrayRecord shard paths for a split.""" + pattern = os.path.join(version_dir, f"c4-{split}.array_record-*") + return sorted(glob.glob(pattern)) + + +def parse_shard_numbers(fname: str) -> tuple[str, str]: + """Extract shard index and total from a shard filename. 
+ + Example: c4-train.array_record-00003-of-00008 -> ("00003", "00008"). + """ + base = os.path.basename(fname) + parts = base.split("-") + # last two parts are shard index and total, e.g. 00003, of, 00008 + shard_idx = parts[-3] + total = parts[-1] + return shard_idx, total + + +def convert_shard(arrayrecord_path: str, output_path: str, force: bool) -> None: + """Convert a single ArrayRecord shard into a TFRecord file. + + If the output exists and ``force`` is False, the function skips conversion. + """ + if os.path.exists(output_path) and not force: + print(f"Skip existing: {output_path}") + return + + reader = ArrayRecordReader(arrayrecord_path) + count = reader.num_records() + written = 0 + batch_size = 1024 + + with tf.io.TFRecordWriter(output_path) as writer: + start = 0 + while start < count: + end = min(start + batch_size, count) + # reader.read(start, end) returns list of records in [start,end) + batch = reader.read(start, end) + for rec in batch: + writer.write(rec) + written += 1 + start = end + + print(f"Converted {arrayrecord_path} -> {output_path} ({written} / {count} records)") + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--version-dir", required=True, help="Directory like c4_en_dataset_minimal/c4/en/3.0.1") + ap.add_argument("--builder-name", default="__local_c4_builder", help="Prefix used for TFRecord shard filenames.") + ap.add_argument("--dry-run", action="store_true", help="Only list planned conversions.") + ap.add_argument("--force", action="store_true", help="Overwrite existing TFRecord shards if present.") + args = ap.parse_args() + + if not os.path.isdir(args.version_dir): + print(f"Version directory not found: {args.version_dir}") + sys.exit(1) + + for split in ["train", "validation"]: + shards = discover_shards(args.version_dir, split) + if not shards: + print(f"No ArrayRecord shards found for split '{split}' in {args.version_dir}") + continue + print(f"Found {len(shards)} {split} ArrayRecord shards.") + for shard in shards: + shard_idx, total = parse_shard_numbers(shard) + tfrec_name = f"{args.builder_name}-{split}.tfrecord-{shard_idx}-of-{total}" + out_path = os.path.join(args.version_dir, tfrec_name) + if args.dry_run: + print(f"Would create: {out_path} from {shard}") + else: + convert_shard(shard, out_path, args.force) + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/local_datasets/generate_tfds_metadata.py b/local_datasets/generate_tfds_metadata.py new file mode 100644 index 000000000..8d472b2d6 --- /dev/null +++ b/local_datasets/generate_tfds_metadata.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate minimal TFDS metadata for existing local C4 ArrayRecord shards. + +Usage: + python local_datasets/generate_tfds_metadata.py \ + --root local_datasets/c4_en_dataset_minimal \ + --version 3.1.0 \ + --source-version 3.0.1 \ + --force + +This script creates a tiny TFDS builder and outputs the ``dataset_info.json`` and +``features.json`` files. 
+ +After running, you can point TFDS to ``--root`` and load with +``dataset_name='c4/en:3.1.0'``. +""" +from __future__ import annotations +import os +import argparse +import tensorflow_datasets as tfds # type: ignore + + +def ensure_symlink(root: str, source_version: str, version: str) -> str: + """Ensure a symlink exists from source_version to version under root/c4/en. + + Returns the target version directory path. + """ + src = os.path.join(root, "c4", "en", source_version) + dst = os.path.join(root, "c4", "en", version) + if not os.path.isdir(src): + raise FileNotFoundError(f"Source version directory not found: {src}") + if not os.path.lexists(dst): + try: + os.symlink(src, dst) + print(f"Created symlink {dst} -> {src}") + except OSError as exc: + print(f"Symlink creation failed (continuing): {exc}") + else: + print(f"Symlink already exists: {dst}") + return dst + + +def write_metadata(root: str, version_dir: str, dataset_version: str, force: bool = False) -> None: + """Write TFDS ``dataset_info.json`` and ``features.json`` for local C4 shards.""" + info_path = os.path.join(version_dir, "dataset_info.json") + if os.path.exists(info_path) and not force: + print("dataset_info.json already exists; skipping overwrite (use --force to regenerate).") + return + + # Discover shards (we assume they exist and are correct; counts are fixed) + num_shards_train = 8 + num_shards_val = 2 + exact_train_records = 1000 + exact_val_records = 200 + + train_records_per_shard = exact_train_records // num_shards_train + val_records_per_shard = exact_val_records // num_shards_val + train_shard_lengths = [train_records_per_shard] * num_shards_train + val_shard_lengths = [val_records_per_shard] * num_shards_val + + train_split = tfds.core.SplitInfo(name="train", shard_lengths=train_shard_lengths, num_bytes=0) + val_split = tfds.core.SplitInfo(name="validation", shard_lengths=val_shard_lengths, num_bytes=0) + + class _LocalC4Builder(tfds.core.GeneratorBasedBuilder): + """Tiny builder used only to materialize TFDS metadata on disk.""" + + VERSION = tfds.core.Version(dataset_version) + BUILDER_CONFIGS = [tfds.core.BuilderConfig(name="en", version=VERSION, description="Local minimal C4 EN subset")] + + def _info(self) -> tfds.core.DatasetInfo: # type: ignore[override] + info = tfds.core.DatasetInfo( + builder=self, + description="Local minimal C4 English subset.", + features=tfds.features.FeaturesDict({"text": tfds.features.Text()}), + homepage="https://www.tensorflow.org/datasets/catalog/c4", + citation="", + ) + info.set_splits({"train": train_split, "validation": val_split}) + return info + + def _split_generators(self, dl_manager): # type: ignore[override] + """No actual generation; data already exists on disk.""" + del dl_manager + return [] + + def _generate_examples(self): # type: ignore[override] + """No example generation; placeholder to satisfy API.""" + yield from () + + builder = _LocalC4Builder(data_dir=root) + info = builder.info + + # Write canonical files (features.json + dataset_info.json) + info.write_to_directory(version_dir) + print(f"Wrote TFDS dataset_info & features to {version_dir}") + + +def main() -> None: + """CLI entry point for generating TFDS metadata.""" + ap = argparse.ArgumentParser() + ap.add_argument( + "--root", + required=True, + help="Root directory containing c4/en/ shards", + ) + ap.add_argument( + "--version", + default="3.1.0", + help="Target version to expose via TFDS", + ) + ap.add_argument( + "--source-version", + default="3.0.1", + help="Existing version directory with 
shards", + ) + ap.add_argument( + "--force", + action="store_true", + help="Overwrite existing dataset_info.json if present", + ) + args = ap.parse_args() + + target_dir = ensure_symlink(args.root, args.source_version, args.version) + write_metadata(args.root, target_dir, args.version, force=args.force) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/local_datasets/get_minimal_c4_en_dataset.py b/local_datasets/get_minimal_c4_en_dataset.py new file mode 100644 index 000000000..8cd74dfec --- /dev/null +++ b/local_datasets/get_minimal_c4_en_dataset.py @@ -0,0 +1,455 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create a small local C4 English dataset from MinIO for offline tests. + +This utility connects to a MinIO instance, finds small source shards (ArrayRecord +or TFRecord), and writes a set of small ArrayRecord shards locally that mimic +the original dataset's sharding. It is intended for test isolation and is not +meant to be a production data pipeline. +""" + +import os +import glob +import sys +import argparse +from typing import List + +from minio import Minio +from minio.error import S3Error + +# ArrayRecord Python bindings +from array_record.python.array_record_module import ArrayRecordWriter, ArrayRecordReader +import tensorflow as tf + + +# ----------------------- +# Configurable parameters +# ----------------------- + +# ------------ MinIO Connection Config (override via env) ------------ +MINIO_ENDPOINT = os.environ.get("MINIO_ENDPOINT", "minio-frameworks.amd.com") +MINIO_ACCESS_KEY = os.environ.get("MINIO_ACCESS_KEY", "hidden") +MINIO_SECRET_KEY = os.environ.get("MINIO_SECRET_KEY", "hidden") +MINIO_SECURE = os.environ.get("MINIO_SECURE", "true").lower() == "true" +BUCKET = os.environ.get("MINIO_C4_BUCKET", "datasets.dl") + +# Versions of c4/en to sample +# VERSIONS = ["3.0.1", "3.0.5", "3.0.7", "3.0.8", "3.0.9"] +VERSIONS = ["3.0.1"] + +# Local output base +LOCAL_BASE = "local_datasets/c4_en_dataset_minimal/c4/en" + +# Shard counts (simulate real behavior) +NUM_SHARDS_TRAIN = 8 +NUM_SHARDS_VAL = 2 + +# Record caps: adjust to control total content size +EXACT_TRAIN_RECORDS = 1000 +EXACT_VAL_RECORDS = 200 + +# Per-output-shard hard cap (bytes) so each file stays under target size +# Adjust as needed; 20 MiB per shard keeps total per version < 50MB. 
+MAX_OUTPUT_SHARD_BYTES = 20 * 1024 * 1024 # 20 MiB per shard + +# Per-version soft cap (for info/warning) +MAX_VERSION_BYTES = 50 * 1024 * 1024 # 50 MiB + +# Temp download cap for TFRecord range GET (no need download the full shard) +MAX_TEMP_DOWNLOAD_BYTES = 200 * 1024 * 1024 # 200 MiB + +# Prefixes in MinIO +ARRAY_RECORD_TRAIN_PREFIX = "c4/en/{ver}/c4-train.array_record-" +ARRAY_RECORD_VAL_PREFIX = "c4/en/{ver}/c4-validation.array_record-" +TFRECORD_TRAIN_PREFIX = "c4/en/{ver}/c4-train.tfrecord-" +TFRECORD_VAL_PREFIX = "c4/en/{ver}/c4-validation.tfrecord-" + + +def ensure_dir(path: str) -> None: + os.makedirs(path, exist_ok=True) + + +def list_matching(client: Minio, bucket: str, prefix: str) -> List: + """Return a sorted list of objects under prefix.""" + return sorted( + (obj for obj in client.list_objects(bucket, prefix=prefix, recursive=True)), + key=lambda o: o.object_name, + ) + + +def pick_smallest(objects): + """Pick the smallest object by size. + + Falls back gracefully if size is unavailable. + """ + return min(objects, key=lambda o: getattr(o, "size", float("inf"))) + + +def download_shard_with_optional_range( + client, + bucket, + obj, + tmp_dir, + allow_range: bool = False, + max_bytes: int = MAX_TEMP_DOWNLOAD_BYTES, +): + """Download shard to a temp file. + + If ``allow_range`` is True and ``obj.size > max_bytes``, download only the + first ``max_bytes`` bytes (safe for TFRecord; not for ArrayRecord). + """ + ensure_dir(tmp_dir) + local_tmp = os.path.join(tmp_dir, os.path.basename(obj.object_name)) + + if allow_range and getattr(obj, "size", None) and obj.size > max_bytes: + # Read first max_bytes bytes for TFRecord; iterator stops at + # incomplete record. + response = client.get_object(bucket, obj.object_name, offset=0, length=max_bytes) + try: + with open(local_tmp, "wb") as f: + for d in response.stream(32 * 1024): + f.write(d) + finally: + response.close() + response.release_conn() + else: + client.fget_object(bucket, obj.object_name, local_tmp) + return local_tmp + + +def write_sharded_with_byte_caps_from_arrayrecord( + src_path, + dst_dir, + split_name, + num_shards, + max_total_records, + max_shard_bytes, +): + """Write multiple ArrayRecord shards from a single ArrayRecord source. + + Records are distributed round-robin with per-shard byte caps. + """ + ensure_dir(dst_dir) + writers = [] + shard_bytes = [0] * num_shards + for i in range(num_shards): + shard_name = f"c4-{split_name}.array_record-{i:05d}-of-{num_shards:05d}" + writers.append(ArrayRecordWriter(os.path.join(dst_dir, shard_name), "group_size:1")) + + reader = ArrayRecordReader(src_path) + n = min(max_total_records, reader.num_records()) + shard_idx = 0 + written = 0 + for i in range(n): + rec = reader.read(i) + rec_len = len(rec) + # If current shard would exceed cap, move to next shard. + if shard_bytes[shard_idx] + rec_len > max_shard_bytes: + shard_idx = (shard_idx + 1) % num_shards + # If next shard is also full, stop early. 
+ if shard_bytes[shard_idx] + rec_len > max_shard_bytes: + break + writers[shard_idx].write(rec) + shard_bytes[shard_idx] += rec_len + written += 1 + shard_idx = (shard_idx + 1) % num_shards + + for writer in writers: + writer.close() + + print( + f"[{split_name}] Wrote {written} records across {num_shards} shards; " + f"per-shard sizes: {[round(b/1024/1024, 2) for b in shard_bytes]} MiB", + ) + return written, shard_bytes + + +def write_sharded_with_byte_caps_from_tfrecord( + src_path, + dst_dir, + split_name, + num_shards, + max_total_records, + max_shard_bytes, +): + """Write multiple ArrayRecord shards from a TFRecord source file.""" + ensure_dir(dst_dir) + writers = [] + shard_bytes = [0] * num_shards + for i in range(num_shards): + shard_name = f"c4-{split_name}.array_record-{i:05d}-of-{num_shards:05d}" + writers.append(ArrayRecordWriter(os.path.join(dst_dir, shard_name), "group_size:1")) + + shard_idx = 0 + count = 0 + for raw_example in tf.compat.v1.io.tf_record_iterator(src_path): + rec_len = len(raw_example) + if shard_bytes[shard_idx] + rec_len > max_shard_bytes: + shard_idx = (shard_idx + 1) % num_shards + if shard_bytes[shard_idx] + rec_len > max_shard_bytes: + break + writers[shard_idx].write(raw_example) + shard_bytes[shard_idx] += rec_len + count += 1 + shard_idx = (shard_idx + 1) % num_shards + if count >= max_total_records: + break + + for writer in writers: + writer.close() + + print( + f"[{split_name}] Wrote {count} records across {num_shards} shards; " + f"per-shard sizes: {[round(b/1024/1024, 2) for b in shard_bytes]} MiB", + ) + return count, shard_bytes + + +def compute_dir_size_bytes(dir_path: str, patterns: List[str]) -> int: + """Sum file sizes for all files matching provided glob patterns in dir_path.""" + total = 0 + for patt in patterns: + for path in glob.glob(os.path.join(dir_path, patt)): + try: + total += os.path.getsize(path) + except OSError: + pass + return total + + +def main() -> None: + """CLI entry point for building a minimal C4/en dataset from MinIO.""" + # Parse command-line arguments. + parser = argparse.ArgumentParser( + description="Create minimal c4/en dataset shards from MinIO Instance", + ) + parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force overwrite all existing dataset files without prompting", + ) + args = parser.parse_args() + + # Use TF v1-style record iterator. + tf.compat.v1.disable_eager_execution() + + # Check which versions already exist. + existing_versions = [] + for ver in VERSIONS: + local_version_dir = os.path.join(LOCAL_BASE, ver) + if os.path.exists(local_version_dir): + shard_files = glob.glob(os.path.join(local_version_dir, "c4-*.array_record-*")) + if shard_files: + existing_versions.append(ver) + + if existing_versions: + if args.force: + print(f"Force mode: Will overwrite existing versions: {existing_versions}") + else: + print(f"Found existing versions: {existing_versions}") + # Check if all versions exist to avoid MinIO connection. + if set(existing_versions) == set(VERSIONS): + print("All versions already exist. Nothing to do.") + print("Use --force to regenerate all versions.") + sys.exit(0) + print("Will skip these and only generate missing versions.") + print("Use --force to overwrite all versions.\n") + + client = Minio( + MINIO_ENDPOINT, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + secure=MINIO_SECURE, + ) + + if not client.bucket_exists(BUCKET): + print(f"Bucket '{BUCKET}' does not exist.") + sys.exit(1) + + print("Bucket exists. 
Starting minimal dataset creation...") + ensure_dir(LOCAL_BASE) + + print("Listing c4/en top-level entries (non-recursive):") + for obj in client.list_objects(BUCKET, prefix="c4/en", recursive=False): + print("-", obj.object_name) + + print("\nListing recursively (first 200 entries):") + i = 0 + for obj in client.list_objects(BUCKET, prefix="c4/en", recursive=True): + print(obj.object_name) + i += 1 + if i >= 200: + break + + for ver in VERSIONS: + local_version_dir = os.path.join(LOCAL_BASE, ver) + + # Skip existing versions unless force mode is enabled. + if not args.force and ver in existing_versions: + print(f"\nSkipping existing version {ver}") + continue + + print(f"\nProcessing c4/en:{ver}") + ensure_dir(local_version_dir) + + # Find source shards for train/validation (prefer ArrayRecord). + train_src_objs = list_matching( + client, + BUCKET, + ARRAY_RECORD_TRAIN_PREFIX.format(ver=ver), + ) + val_src_objs = list_matching( + client, + BUCKET, + ARRAY_RECORD_VAL_PREFIX.format(ver=ver), + ) + + train_is_arrayrec = True + val_is_arrayrec = True + + if not train_src_objs: + train_src_objs = list_matching( + client, + BUCKET, + TFRECORD_TRAIN_PREFIX.format(ver=ver), + ) + train_is_arrayrec = False + if not val_src_objs: + val_src_objs = list_matching( + client, + BUCKET, + TFRECORD_VAL_PREFIX.format(ver=ver), + ) + val_is_arrayrec = False + + if not train_src_objs: + print(f"Warning: No train shards found for {ver}. Skipping this version.") + continue + if not val_src_objs: + print( + "Warning: No validation shards found for " f"{ver}. Skipping validation for this version.", + ) + continue + + # Pick the smallest shard per split to minimize download. + smallest_train = pick_smallest(train_src_objs) + smallest_val = pick_smallest(val_src_objs) + + # Download one shard per split to a temp folder. + tmp_dir = os.path.join(local_version_dir, "_tmp_download") + try: + train_src_local = download_shard_with_optional_range( + client, + BUCKET, + smallest_train, + tmp_dir, + allow_range=not train_is_arrayrec, + max_bytes=MAX_TEMP_DOWNLOAD_BYTES, + ) + val_src_local = download_shard_with_optional_range( + client, + BUCKET, + smallest_val, + tmp_dir, + allow_range=not val_is_arrayrec, + max_bytes=MAX_TEMP_DOWNLOAD_BYTES, + ) + except S3Error as exc: + print(f"Download error for version {ver}: {exc}") + # Clean up and skip. + try: + if os.path.exists(tmp_dir): + for filename in os.listdir(tmp_dir): + os.remove(os.path.join(tmp_dir, filename)) + os.rmdir(tmp_dir) + except OSError: + pass + continue + + # Write minimal multi-shard ArrayRecord files with per-shard caps. + try: + if train_is_arrayrec: + write_sharded_with_byte_caps_from_arrayrecord( + train_src_local, + local_version_dir, + "train", + NUM_SHARDS_TRAIN, + EXACT_TRAIN_RECORDS, + MAX_OUTPUT_SHARD_BYTES, + ) + else: + write_sharded_with_byte_caps_from_tfrecord( + train_src_local, + local_version_dir, + "train", + NUM_SHARDS_TRAIN, + EXACT_TRAIN_RECORDS, + MAX_OUTPUT_SHARD_BYTES, + ) + + if val_is_arrayrec: + write_sharded_with_byte_caps_from_arrayrecord( + val_src_local, + local_version_dir, + "validation", + NUM_SHARDS_VAL, + EXACT_VAL_RECORDS, + MAX_OUTPUT_SHARD_BYTES, + ) + else: + write_sharded_with_byte_caps_from_tfrecord( + val_src_local, + local_version_dir, + "validation", + NUM_SHARDS_VAL, + EXACT_VAL_RECORDS, + MAX_OUTPUT_SHARD_BYTES, + ) + + # Post-write size check. 
+ total_bytes = compute_dir_size_bytes( + local_version_dir, patterns=["c4-train.array_record-*", "c4-validation.array_record-*"] + ) + mb = total_bytes / (1024 * 1024) + print(f"Total size for {ver}: {mb:.2f} MiB") + if total_bytes > MAX_VERSION_BYTES: + print( + f"Note: {ver} exceeds {MAX_VERSION_BYTES/(1024*1024):.0f} MiB. " + "Consider reducing records or MAX_OUTPUT_SHARD_BYTES.", + ) + finally: + # Clean up temp downloads. + try: + if os.path.exists(train_src_local): + os.remove(train_src_local) + if os.path.exists(val_src_local): + os.remove(val_src_local) + if os.path.isdir(tmp_dir): + for filename in os.listdir(tmp_dir): + os.remove(os.path.join(tmp_dir, filename)) + os.rmdir(tmp_dir) + except OSError as cleanup_err: + print(f"Cleanup warning: {cleanup_err}") + + print("\nDone. Verify local directories:") + for ver in VERSIONS: + print(f"- {os.path.join(LOCAL_BASE, ver)}") + for path in sorted(glob.glob(os.path.join(LOCAL_BASE, ver, "c4-*.array_record-*"))): + print(f" {os.path.basename(path)}") + + +if __name__ == "__main__": + main() diff --git a/local_datasets/get_minimal_hf_c4_parquet.py b/local_datasets/get_minimal_hf_c4_parquet.py new file mode 100644 index 000000000..7d3aac7c8 --- /dev/null +++ b/local_datasets/get_minimal_hf_c4_parquet.py @@ -0,0 +1,175 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal C4/en TFRecord -> Parquet converter. + +Fetch the first train & validation TFRecord 00000-of shard for a version and +sample rows into two tiny parquet files with fixed output names for the usage +in tests/grain_data_processing_test.py, tests/hf_data_processing_test.py, +tests/train_tests.py: + c4-train-00000-of-01637.parquet + c4-validation-00000-of-01637.parquet +""" + +import argparse +import os +from pathlib import Path + +from minio import Minio +import pyarrow as pa +import pyarrow.parquet as pq + +import tensorflow as tf + +# ---------------- Environment / Defaults ---------------- +MINIO_ENDPOINT = os.environ.get("MINIO_ENDPOINT", "minio-frameworks.amd.com") +MINIO_ACCESS_KEY = os.environ.get("MINIO_ACCESS_KEY", "hidden") +MINIO_SECRET_KEY = os.environ.get("MINIO_SECRET_KEY", "hidden") +MINIO_SECURE = os.environ.get("MINIO_SECURE", "true").lower() == "true" +BUCKET = os.environ.get("MINIO_C4_BUCKET", "datasets.dl") +SCRIPT_DIR = Path(__file__).parent + + +def download_object(client: Minio, obj, dest_path: Path) -> Path: + """Download an object from MinIO to ``dest_path`` and return the path.""" + data = client.get_object(BUCKET, obj.object_name) + try: + with dest_path.open("wb") as f: + for chunk in data.stream(128 * 1024): + f.write(chunk) + finally: + data.close() + data.release_conn() + return dest_path + + +def write_parquet(path: Path, rows: list[str], force: bool = False) -> None: + """Write ``rows`` to a Parquet file at ``path``. + + If ``force`` is False and the file exists, it is left untouched. 
+ """ + if not force and path.exists(): + print(f"[skip] {path} exists") + return + # If force is set and path exists, remove it first + if force and path.exists(): + path.unlink() + + # Normalize & drop empties again defensively. + rows = [r.strip() for r in rows if isinstance(r, str) and r.strip()] + table = pa.Table.from_pydict({"text": rows}) + pq.write_table(table, path, compression="ZSTD") + print(f"[write] {path} rows={len(rows)} size_kib={path.stat().st_size/1024:.1f}") + + +def sample_tfrecord(path: Path, cap: int) -> list[str]: + """Sample up to ``cap`` records from a TFRecord, extracting the ``text`` feature.""" + feature_spec = {"text": tf.io.FixedLenFeature([], tf.string)} + rows: list[str] = [] + for raw in tf.data.TFRecordDataset(str(path)).take(cap): + parsed = tf.io.parse_single_example(raw, feature_spec) + txt = parsed["text"].numpy().decode("utf-8", "ignore").strip() + if txt: + rows.append(txt) + return rows + + +def main() -> None: + """CLI entry point to generate tiny Parquet files from minimal C4 TFRecords.""" + parser = argparse.ArgumentParser( + description="Minimal C4 TFRecord -> parquet generator", + ) + parser.add_argument("--version", default="3.0.1") + parser.add_argument("--train-rows", type=int, default=800) + parser.add_argument("--val-rows", type=int, default=160) + parser.add_argument( + "--output-dir", + default=str(SCRIPT_DIR / "c4_en_dataset_minimal" / "hf" / "c4"), + ) + parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force overwrite existing parquet files", + ) + args = parser.parse_args() + + # Resolve output paths first so we can early stop. + out_dir = Path(args.output_dir) + if not out_dir.exists(): + out_dir.mkdir(parents=True, exist_ok=True) + train_out = out_dir / "c4-train-00000-of-01637.parquet" + val_out = out_dir / "c4-validation-00000-of-01637.parquet" + + if not args.force and train_out.exists() and val_out.exists(): + print("Both output parquet files already exist; skipping (no MinIO connection needed).") + print("Use --force to regenerate the files.") + return + + client = Minio( + MINIO_ENDPOINT, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + secure=MINIO_SECURE, + ) + if not client.bucket_exists(BUCKET): + print("Bucket missing; abort.") + return + + ver = args.version + # List to find the first train 00000-of shard. 
+ train_prefix = f"c4/en/{ver}/c4-train.tfrecord-00000-of-" + train_obj = None + for obj in client.list_objects(BUCKET, prefix=train_prefix, recursive=False): + train_obj = obj + break + if not train_obj: + print("Train 00000-of shard not found; abort.") + return + + val_prefix = f"c4/en/{ver}/c4-validation.tfrecord-00000-of-" + val_obj = None + for obj in client.list_objects(BUCKET, prefix=val_prefix, recursive=False): + val_obj = obj + break + if not val_obj: + print("Validation 00000-of shard not found; abort.") + return + print( + f"Using train object {train_obj.object_name} and validation object " f"{val_obj.object_name}.", + ) + + tmp_train = out_dir.parent / "_tmp_train" + download_object(client, train_obj, tmp_train) + rows_train = sample_tfrecord(tmp_train, args.train_rows) + try: + tmp_train.unlink() + except OSError: + pass + + tmp_val = out_dir.parent / "_tmp_val" + download_object(client, val_obj, tmp_val) + rows_val = sample_tfrecord(tmp_val, args.val_rows) + try: + tmp_val.unlink() + except OSError: + pass + + print(f"Rows: train={len(rows_train)} val={len(rows_val)}") + write_parquet(train_out, rows_train, force=args.force) + write_parquet(val_out, rows_val, force=args.force) + + +if __name__ == "__main__": + main() diff --git a/pytest.ini b/pytest.ini index 1b7f0c635..c851cf939 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,7 +16,13 @@ markers = tpu_only: marks tests to be run on TPUs only gpu_only: marks tests to be run on GPUs only cpu_only: marks tests to be run on CPUs only + decoupled: tests that validate offline / DECOUPLE_GCLOUD=TRUE mode. + NOTE: this marker is not to be used manually, it is auto- + applied to tests with external_* or tpu_only marker. scheduled_only: marks tests to run only on scheduled CI runs integration_test: tests exercising larger portions of the system, including interactions with other systems like GCS, e.g., end_to_end tests + external_serving: JetStream / serving / decode server components + external_training: goodput integrations + diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 5be8df263..8310d766e 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -256,7 +256,7 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward # and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay). model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active. -# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step. +# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step. # False: This is Zero-3 Sharing. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass. pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration. # This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed @@ -306,7 +306,7 @@ param_scan_axis: 1 # The attention_type parameter determines the variants of attention, e.g. 
global or local_sliding attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla -attention_bias: False # If True, adds a learnable bias to the query, key, and value projections +attention_bias: False # If True, adds a learnable bias to the query, key, and value projections attention_sink: False sliding_window_size: 0 chunk_attn_window_size: 0 @@ -424,7 +424,7 @@ logical_axis_rules: [ ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_tensor_transpose', ['tensor_transpose']], + ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], @@ -530,7 +530,7 @@ per_device_batch_size: 12.0 # Each data-loading host will load per_device_batch_size * expansion_factor_real_data. # When set to between 0 and 1, it's for grain pipeline to use a smaller chip count to read checkpoint from a larger chip count job. # Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_grain.md#using-grain -expansion_factor_real_data: -1.0 +expansion_factor_real_data: -1.0 eval_per_device_batch_size: 0.0 max_corpus_chars: 10_000_000 train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" @@ -595,14 +595,15 @@ grain_train_files: '' grain_eval_files: '' grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data. grain_file_type: 'arrayrecord' # arrayrecord or parquet -grain_worker_count: 1 +grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html grain_per_worker_buffer_size: 1 # num_threads and prefetch_buffer_size are per-worker per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions) # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage. -grain_num_threads: 16 +grain_num_threads: 16 grain_prefetch_buffer_size: 500 grain_worker_count_eval: 1 grain_per_worker_buffer_size_eval: 1 +grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only used when grain_worker_count is -1. grain_num_threads_eval: 16 grain_prefetch_buffer_size_eval: 500 grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources. @@ -930,7 +931,7 @@ temporal_patch_size_for_vit: 2 num_position_embeddings_for_vit: 1024 deepstack_visual_indexes_for_vit: [] -# Subslice shape in the form of "x,y,z" when using pathways (single controller). +# Subslice shape in the form of "x,y,z" when using pathways (single controller). # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium. 
subslice_shape: "" diff --git a/src/MaxText/configs/decoupled_base_test.yml b/src/MaxText/configs/decoupled_base_test.yml new file mode 100644 index 000000000..650d09e30 --- /dev/null +++ b/src/MaxText/configs/decoupled_base_test.yml @@ -0,0 +1,79 @@ +# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml. +# Inherit all model defaults from base.yml but override any cloud-coupled paths and disable optional cloud features. +base_config: base.yml + +# Output goes to a local relative directory so tests do not require GCS. +base_output_directory: ./maxtext_local_output +run_name: test_decoupled + +# Disable checkpointing by default for speed unless a test explicitly enables it. +enable_checkpointing: false +vertex_tensorboard_project: "" +use_vertex_tensorboard: false +vertex_tensorboard_region: "" + +# Minimize batch/steps to reduce compilation/test time. +per_device_batch_size: 1 +steps: 2 +learning_rate_schedule_steps: 2 + +# Disable profiler to avoid extra setup. +profiler: "" +profile_periodically_period: 0 +profiler_steps: 0 + +# Leave dataset-related keys to be overridden by individual tests. +dataset_type: "" + +# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs +attention: "dot_product" + +# Avoid HLO dump overhead. +dump_hlo: false +jax_cache_dir: "" + +# Neutral parallelism (single device) for local tests. +ici_data_parallelism: 1 +ici_tensor_parallelism: 1 +ici_pipeline_parallelism: 1 +ici_expert_parallelism: 1 +ici_sequence_parallelism: 1 +ici_context_parallelism: 1 +ici_tensor_transpose_parallelism: 1 +ici_tensor_sequence_parallelism: 1 +ici_autoregressive_parallelism: 1 +ici_fsdp_parallelism: 1 +ici_fsdp_transpose_parallelism: 1 + +# DCN dimensions to 1 (no multi-slice expectation locally). +dcn_data_parallelism: 1 +dcn_tensor_parallelism: 1 +dcn_pipeline_parallelism: 1 +dcn_expert_parallelism: 1 +dcn_sequence_parallelism: 1 +dcn_context_parallelism: 1 +dcn_tensor_transpose_parallelism: 1 +dcn_tensor_sequence_parallelism: 1 +dcn_autoregressive_parallelism: 1 +dcn_fsdp_parallelism: 1 +dcn_fsdp_transpose_parallelism: 1 + +# Config logging off unless a test overrides. +log_config: false + +# Explicitly disable all Goodput / GCP metric integrations. +enable_goodput_recording: false +monitor_goodput: false +goodput_upload_interval_seconds: 0 +enable_pathways_goodput: false +enable_gcp_goodput_metrics: false + +# Disable any cloud logging / BigQuery or external metric uploads. +enable_cloud_logging: false +upload_metrics_to_bigquery: false +bigquery_project: "" +bigquery_dataset: "" +bigquery_table: "" + +# Force local-only behavior for tests: avoid accidental env pickup. +tensorboard_dir: "./maxtext_local_output/tensorboard" diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 4f2406534..861201c03 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -860,6 +860,7 @@ class GrainDataset(BaseModel): grain_per_worker_buffer_size_eval: int = Field( 1, description="Buffer size for each worker for Grain data loading during evaluation." 
) + grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.") grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.") grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.") grain_num_threads_eval: int = Field(16, description="Number of threads for Grain ReadOptions during evaluation.") diff --git a/src/MaxText/gcloud_stub.py b/src/MaxText/gcloud_stub.py new file mode 100644 index 000000000..852eeedd4 --- /dev/null +++ b/src/MaxText/gcloud_stub.py @@ -0,0 +1,535 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Centralized decoupling helpers. + +Set DECOUPLE_GCLOUD=TRUE in the environment to disable optional Google Cloud / JetStream / GCS / diagnostics +integrations while still allowing local unit tests to import modules. This module provides: + +- is_decoupled(): returns True if decoupled flag set. +- cloud_diagnostics(): tuple(diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration) + providing either real objects or lightweight stubs. +- jetstream(): returns a namespace-like object exposing Engine, Devices, ResultTokens etc. or stubs. +- gcs_storage(): returns google.cloud.storage module or stub namespace with Client/Blob/Bucket. +- goodput_modules(): returns (goodput, monitoring, is_stub) for ml_goodput_measurement integration or stubs. +- monitoring_modules(): returns (monitoring_v3, metric_pb2, monitored_resource_pb2, GoogleAPIError, is_stub) + for Google Cloud Monitoring integration or stubs. + +All stubs raise RuntimeError only when actually invoked, not at import time, so test collection proceeds. 
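+
+Example (an illustrative sketch, not exercised by this change; the bucket and object
+names are made up):
+
+  from MaxText.gcloud_stub import gcs_storage, goodput_modules
+
+  storage = gcs_storage()  # real google.cloud.storage module, or a stub namespace when decoupled
+  if not getattr(storage, "_IS_STUB", False):
+    storage.Client().bucket("my-bucket").blob("metrics.json").upload_from_string("{}")
+
+  goodput, monitoring, is_stub = goodput_modules()
+  if is_stub:
+    print("goodput recording disabled; running decoupled")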
+""" +from __future__ import annotations + +from types import SimpleNamespace +import importlib.util +import os + + +def is_decoupled() -> bool: # dynamic check so setting env after initial import still works + """Return True when DECOUPLE_GCLOUD environment variable is set to TRUE.""" + return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE" + + +# ---------------- Cloud Diagnostics ----------------- + + +def _cloud_diag_stubs(): + """Return lightweight stubs for cloud TPU diagnostics.""" + import contextlib # pylint: disable=import-outside-toplevel + + class _StubDiag: + """Stub diagnostic object returning skip metadata.""" + + def run(self, *_a, **_k): + return {"status": "skipped"} + + def diagnose(self, *_a, **_k): + """Return a context manager that swallows diagnostic errors in stub mode.""" + + @contextlib.contextmanager + def _graceful_diagnose(): + try: + yield + except Exception as exc: # pylint: disable=broad-exception-caught + print("Warning: using stubs for cloud_diagnostics diagnose() - " f"caught: {exc}") + + return _graceful_diagnose() + + class _StubDebugConfig: + """Stub debug configuration.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + class _StubStackTraceConfig: + """Stub stack trace configuration.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + class _StubDiagnosticConfig: + """Stub diagnostic configuration wrapper.""" + + def __init__(self, *a, debug_config=None, **k): # pylint: disable=unused-argument + del a, k + self.debug_config = debug_config + + return ( + _StubDiag(), + SimpleNamespace(DebugConfig=_StubDebugConfig, StackTraceConfig=_StubStackTraceConfig), + SimpleNamespace(DiagnosticConfig=_StubDiagnosticConfig), + SimpleNamespace(StackTraceConfig=_StubStackTraceConfig), + ) + + +def cloud_diagnostics(): + """Return real cloud diagnostics modules or stubs. + + If a dependency is missing and we are decoupled, return stubs. Otherwise + re-raise the import error so callers fail fast. + """ + try: + from cloud_tpu_diagnostics import diagnostic # type: ignore # pylint: disable=import-outside-toplevel + from cloud_tpu_diagnostics.configuration import ( # type: ignore # pylint: disable=import-outside-toplevel + debug_configuration, + diagnostic_configuration, + stack_trace_configuration, + ) + + return diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration + except ModuleNotFoundError: + if is_decoupled(): + print("[DECOUPLED NO-OP] cloud_diagnostics: dependency missing; using stubs.") + return _cloud_diag_stubs() + raise + + +# ---------------- JetStream ----------------- + + +def _jetstream_stubs(): + """Return lightweight stubs for JetStream modules.""" + + class Engine: # minimal base class stub + """Stub Engine accepting any initialization signature.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + class ResultTokens: + """Container for result token arrays used by JetStream.""" + + def __init__( + self, + *args, + data=None, + tokens_idx=None, + valid_idx=None, + length_idx=None, + log_prob=None, + samples_per_slot: int | None = None, + **kwargs, + ): + del args, kwargs # unused + self.data = data + self.tokens_idx = tokens_idx + self.valid_idx = valid_idx + self.length_idx = length_idx + self.log_prob = log_prob + self.samples_per_slot = samples_per_slot + + # Tokenizer placeholders (unused in decoupled tests due to runtime guard). 
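+  # They mirror the TokenizerParameters / TokenizerType names that jetstream() imports from
+  # jetstream.engine.tokenizer_pb2, so attribute access on the returned namespace works without
+  # the real dependency installed.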
+ class TokenizerParameters: # pragma: no cover - placeholder + """Stub tokenizer parameters object.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + class TokenizerType: # emulate enum descriptor access pattern + """Stub tokenizer type descriptor container.""" + + DESCRIPTOR = SimpleNamespace(values_by_name={}) + + config_lib = SimpleNamespace() # not used directly in decoupled tests + engine_api = SimpleNamespace(Engine=Engine, ResultTokens=ResultTokens) + token_utils = SimpleNamespace() # build_tokenizer guarded in MaxEngine when decoupled + tokenizer_api = SimpleNamespace() # placeholder + token_params_ns = SimpleNamespace(TokenizerParameters=TokenizerParameters, TokenizerType=TokenizerType) + return config_lib, engine_api, token_utils, tokenizer_api, token_params_ns + + +def jetstream(): + """Return JetStream modules or stubs based on availability and decoupling.""" + needed = [ + "jetstream.core.config_lib", + "jetstream.engine.engine_api", + "jetstream.engine.token_utils", + "jetstream.engine.tokenizer_api", + "jetstream.engine.tokenizer_pb2", + ] + try: + for mod in needed: + if importlib.util.find_spec(mod) is None: + if is_decoupled(): + print("[DECOUPLED NO-OP] jetstream: dependency missing; using stubs.") + return _jetstream_stubs() + raise ModuleNotFoundError(mod) + + from jetstream.core import config_lib # type: ignore # pylint: disable=import-outside-toplevel + from jetstream.engine import engine_api, token_utils, tokenizer_api # type: ignore # pylint: disable=import-outside-toplevel + from jetstream.engine.tokenizer_pb2 import TokenizerParameters, TokenizerType # type: ignore # pylint: disable=import-outside-toplevel + + return ( + config_lib, + engine_api, + token_utils, + tokenizer_api, + SimpleNamespace(TokenizerParameters=TokenizerParameters, TokenizerType=TokenizerType), + ) + except ModuleNotFoundError: + if is_decoupled(): + print("[DECOUPLED NO-OP] jetstream: dependency missing; using stubs.") + return _jetstream_stubs() + raise + + +# ---------------- GCS ----------------- + + +def _gcs_stubs(): # pragma: no cover - simple no-op placeholders + """Return stub implementations of the google.cloud.storage API.""" + + class _StubBlob: + """Stub GCS blob with no-op operations.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + def upload_from_filename(self, *a, **k): # pylint: disable=unused-argument + return False + + def upload_from_string(self, *a, **k): # pylint: disable=unused-argument + return False + + def exists(self): + return False + + def download_as_string(self): + return b"{}" + + class _StubListPages: + """Stub for iterable pages returned by list_blobs.""" + + def __init__(self): + self.pages = [SimpleNamespace(prefixes=[])] + + class _StubBucket: + """Stub GCS bucket returning stub blobs and pages.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + def blob(self, *a, **k): # pylint: disable=unused-argument + return _StubBlob() + + def list_blobs(self, *a, **k): # pylint: disable=unused-argument + return _StubListPages() + + class _StubClient: + """Stub GCS client exposing bucket helpers.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + def get_bucket(self, *a, **k): # pylint: disable=unused-argument + return _StubBucket() + + def bucket(self, *a, **k): # pylint: disable=unused-argument + return _StubBucket() + + return SimpleNamespace(Client=_StubClient, _IS_STUB=True) + + +def gcs_storage(): + """Return 
google.cloud.storage module or stub when decoupled or missing.""" + # In decoupled mode always prefer the stub, even if the library is installed, + # to avoid accidental GCS calls in tests or local runs. + if is_decoupled(): # fast path + print("[DECOUPLED NO-OP] gcs_storage: dependency missing; using stubs.") + return _gcs_stubs() + + try: # pragma: no cover - attempt real import when not decoupled + from google.cloud import storage # type: ignore # pylint: disable=import-outside-toplevel + + setattr(storage, "_IS_STUB", False) + return storage + except Exception: # ModuleNotFoundError / ImportError for partial installs # pylint: disable=broad-exception-caught + print("[DECOUPLED NO-OP] gcs_storage: dependency missing; using stubs.") + return _gcs_stubs() + + +# ---------------- Goodput (ml_goodput_measurement) ----------------- + + +def _goodput_stubs(): + """Return stubs for ml_goodput_measurement integration.""" + + class _StubGoodputRecorder: + """Recorder stub exposing no-op methods and disabled flag.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + self.enabled = False + + def __getattr__(self, name): + def _noop(*_a, **_k): + pass + + return _noop + + class _StubMonitoringOptions: + """Stub monitoring options container.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + class _StubGoodputMonitor: + """Stub goodput monitor with no-op uploader methods.""" + + def __init__(self, *a, **_k): # pylint: disable=unused-argument + pass + + def start_goodput_uploader(self): + print("[DECOUPLED NO-OP] goodput uploader skipped.") + + def start_step_deviation_uploader(self): + print("[DECOUPLED NO-OP] goodput step deviation uploader skipped.") + + monitoring_ns = SimpleNamespace(GCPOptions=_StubMonitoringOptions, GoodputMonitor=_StubGoodputMonitor) + goodput_ns = SimpleNamespace(GoodputRecorder=_StubGoodputRecorder) + return goodput_ns, monitoring_ns, True + + +def goodput_modules(): + """Return real goodput modules or stubs when missing and decoupled.""" + try: + from ml_goodput_measurement import goodput, monitoring # type: ignore # pylint: disable=import-outside-toplevel + + return goodput, monitoring, False + except ModuleNotFoundError: + if is_decoupled(): + print("[DECOUPLED NO-OP] ml_goodput_measurement: dependency missing; using stubs.") + return _goodput_stubs() + raise + + +__all__ = ["is_decoupled", "cloud_diagnostics", "jetstream", "gcs_storage", "goodput_modules"] + +# ---------------- Cloud Monitoring (monitoring_v3 / metric_pb2) ----------------- + + +def _monitoring_stubs(): # pragma: no cover - simple placeholders + """Return stub implementations for Cloud Monitoring APIs.""" + + class GoogleAPIError(Exception): + """Stub GoogleAPIError mirroring the real exception name.""" + + class _DummyMonitoringV3: + """Dummy monitoring module providing minimal types.""" + + class TimeSeries: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + class Point: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + class TimeInterval: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + class TypedValue: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + class MetricServiceClient: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + def create_time_series(self, *a, **k): # pylint: disable=unused-argument + return False + + class _DummyMetricPB2: + """Dummy metric_pb2 module namespace.""" + + class 
Metric: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + class _DummyMonitoredResourcePB2: + """Dummy monitored_resource_pb2 module namespace.""" + + class MonitoredResource: + + def __init__(self, *a, **k): # pylint: disable=unused-argument + del a, k + + return _DummyMonitoringV3(), _DummyMetricPB2(), _DummyMonitoredResourcePB2(), GoogleAPIError, True + + +def monitoring_modules(): + """Return monitoring modules or stubs. + + Stubs only if decoupled AND dependency missing; if not decoupled and missing -> + re-raise. + """ + try: # Attempt real imports first + from google.cloud import monitoring_v3 # type: ignore # pylint: disable=import-outside-toplevel + from google.api import metric_pb2, monitored_resource_pb2 # type: ignore # pylint: disable=import-outside-toplevel + from google.api_core.exceptions import GoogleAPIError # type: ignore # pylint: disable=import-outside-toplevel + + return monitoring_v3, metric_pb2, monitored_resource_pb2, GoogleAPIError, False + except (ModuleNotFoundError, ImportError): # broaden to handle partial google installs + if is_decoupled(): + print("[DECOUPLED NO-OP] monitoring: dependency missing; using stubs.") + return _monitoring_stubs() + raise + + +__all__.append("monitoring_modules") + +# ---------------- Workload Monitor (GCPWorkloadMonitor) ----------------- + + +def _workload_monitor_stub(): # pragma: no cover - simple placeholder + """Return stub GCPWorkloadMonitor implementation and stub flag.""" + + class GCPWorkloadMonitor: + """Stub of GCPWorkloadMonitor exposing no-op methods.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + def start_heartbeat_reporting_thread(self, *a, **k): # pylint: disable=unused-argument + pass + + def start_performance_reporting_thread(self, *a, **k): # pylint: disable=unused-argument + pass + + return GCPWorkloadMonitor, True + + +def workload_monitor(): + """Return (GCPWorkloadMonitor, is_stub) centralizing stub logic. + + If decoupled OR import fails, returns stub class; otherwise real class. + """ + if is_decoupled(): # fast path: never attempt heavy import + print("[DECOUPLED NO-OP] workload_monitor: using stub.") + return _workload_monitor_stub() + + try: + from MaxText.gcp_workload_monitor import GCPWorkloadMonitor # type: ignore # pylint: disable=import-outside-toplevel + + return GCPWorkloadMonitor, False + except Exception: # ModuleNotFoundError / ImportError # pylint: disable=broad-exception-caught + print("[NO-OP] workload_monitor dependency missing; using stub.") + return _workload_monitor_stub() + + +__all__.append("workload_monitor") + +# ---------------- Vertex Tensorboard ----------------- + + +def _vertex_tb_stub(): # pragma: no cover - simple placeholder + """Return stub VertexTensorboardManager implementation and stub flag.""" + + class VertexTensorboardManager: + """Stub VertexTensorboardManager with no-op configure method.""" + + def __init__(self, *a, **k): # pylint: disable=unused-argument + pass + + def configure_vertex_tensorboard(self, *a, **k): # pylint: disable=unused-argument + # NO-OP in decoupled / missing dependency mode + pass + + return VertexTensorboardManager, True + + +def vertex_tensorboard_components(): + """Return (VertexTensorboardManager, is_stub). + + Decoupled or missing dependency -> stub class with no-op configure method. 
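+
+  Sketch of the intended call pattern (assuming the real manager keeps a
+  configure_vertex_tensorboard(config) method, which the stub mirrors as a no-op):
+
+    VertexTensorboardManager, is_stub = vertex_tensorboard_components()
+    manager = VertexTensorboardManager()
+    manager.configure_vertex_tensorboard(config)  # silently does nothing when is_stub is True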
+ """ + if is_decoupled(): + print("[DECOUPLED NO-OP] vertex_tensorboard: using stub.") + return _vertex_tb_stub() + + try: + from MaxText.vertex_tensorboard import VertexTensorboardManager # type: ignore # pylint: disable=import-outside-toplevel + + return VertexTensorboardManager, False + except Exception: # pylint: disable=broad-exception-caught + print("[NO-OP] vertex_tensorboard dependency missing; using stub.") + return _vertex_tb_stub() + + +__all__.append("vertex_tensorboard_components") + +# ---------------- TensorBoardX (moved stub) ----------------- + +try: + if not is_decoupled(): # Only attempt real import when not decoupled + from tensorboardX import writer # type: ignore # pylint: disable=import-outside-toplevel,unused-import + + _TENSORBOARDX_AVAILABLE = True + else: + raise ModuleNotFoundError("Decoupled mode skips tensorboardX import") +except Exception: # pragma: no cover - provide stub fallback # pylint: disable=broad-exception-caught + _TENSORBOARDX_AVAILABLE = False + + class _DummySummaryWriter: + """Stubbed TensorBoardX SummaryWriter replacement.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + del args, kwargs + + def add_text(self, *args, **kwargs): + pass + + def add_scalar(self, *args, **kwargs): + pass + + def add_histogram(self, *args, **kwargs): + pass + + def flush(self): + pass + + def close(self): + pass + + class writer: # pylint: disable=too-few-public-methods + SummaryWriter = _DummySummaryWriter + + +__all__.append("writer") +__all__.append("_TENSORBOARDX_AVAILABLE") diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index 74c72046a..809de5629 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -23,6 +23,7 @@ import jax +from grain.experimental import pick_performance_config import grain.python as grain from MaxText.utils import gcs_utils @@ -230,12 +231,20 @@ def pretrain_preprocessing_pipeline( axis=1, ) ) - dataset = dataset.mp_prefetch( - grain.MultiprocessingOptions( + multiprocessing_options = ( + pick_performance_config( + ds=dataset, + ram_budget_mb=config.grain_ram_budget_mb, + max_workers=None, + max_buffer_size=None, + ).multiprocessing_options + if grain_worker_count == -1 + else grain.MultiprocessingOptions( num_workers=grain_worker_count, per_worker_buffer_size=grain_per_worker_buffer_size, ) ) + dataset = dataset.mp_prefetch(multiprocessing_options) return dataset @@ -273,12 +282,20 @@ def dpo_preprocessing_pipeline( batch_size = config.global_batch_size_to_load // jax.process_count() batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) dataset = dataset.batch(batch_size, batch_fn=batch_fn) - dataset = dataset.mp_prefetch( - grain.MultiprocessingOptions( + multiprocessing_options = ( + pick_performance_config( + ds=dataset, + ram_budget_mb=config.grain_ram_budget_mb, + max_workers=None, + max_buffer_size=None, + ).multiprocessing_options + if grain_worker_count == -1 + else grain.MultiprocessingOptions( num_workers=grain_worker_count, per_worker_buffer_size=grain_per_worker_buffer_size, ) ) + dataset = dataset.mp_prefetch(multiprocessing_options) return dataset diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index aba50a8d6..1f6473a08 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -412,7 +412,7 @@ def get_decoder_layers(self): case 
DecoderBlockType.GEMMA3: return [gemma3.Gemma3DecoderLayerToLinen] case DecoderBlockType.GPT3: - return [gpt3.Gpt3DecoderLayer] + return [gpt3.Gpt3DecoderLayerToLinen] case DecoderBlockType.GPT_OSS: return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] case DecoderBlockType.QWEN3: @@ -598,7 +598,7 @@ def _apply_embedding( name="position_embedder", config=cfg, mesh=self.mesh, - )(decoder_positions, model_mode=model_mode) + )(decoder_positions.astype("int32"), model_mode=model_mode) return y @nn.compact diff --git a/src/MaxText/layers/gpt3.py b/src/MaxText/layers/gpt3.py index 831677583..1bec2bb26 100644 --- a/src/MaxText/layers/gpt3.py +++ b/src/MaxText/layers/gpt3.py @@ -16,7 +16,7 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module -from typing import Any +from typing import Any, Optional import jax from jax import lax @@ -31,12 +31,12 @@ from MaxText import max_utils from MaxText.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN from MaxText.layers import initializers, nnx_wrappers -from MaxText.layers.linears import mlp_block +from MaxText.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes from MaxText.layers import models from MaxText.layers import quantizations -from MaxText.layers.attention_op import KVQuant, attention_op_as_linen +from MaxText.layers import linears +from MaxText.layers.attentions import AttentionOp, KVQuant from MaxText.layers.initializers import Initializer, NdInitializer, nd_dense_init -from MaxText.layers.linears import dense_general from MaxText.layers.quantizations import AqtQuantization as Quant # ----------------------------------------- @@ -163,7 +163,7 @@ def gpt3_layer_norm( # ----------------------------------------- -class Gpt3MultiHeadAttention(nn.Module): +class Gpt3MultiHeadAttention(nnx.Module): """Multi-head attention in gpt3. Attributes: @@ -185,102 +185,138 @@ class Gpt3MultiHeadAttention(nn.Module): use_bias: whether to add bias in linear transformation. """ - config: Config - num_heads: int - head_dim: int - max_target_length: int - max_prefill_predict_length: int - mesh: Mesh - attention_kernel: str - dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 - dropout_rate: float = 0.0 - kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") - float32_qk_product: bool = False # computes logits in float32 for stability. - float32_logits: bool = True # cast logits in float32 for stability. - fused_qkv: bool = True - quant: None | Quant = None - kv_quant: None | KVQuant = None - use_bias: bool = True - - input_axis_names: AxisNames = (BATCH, LENGTH, EMBED) - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) - key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) - value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) - out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) - - def qkv_projection(self, inputs: Array, proj_name: str): - """Fused QKV projection""" + def __init__( + self, + config: Config, + model_mode: str, + num_heads: int, + feature_dim: tuple[int, ...], + head_dim: int, + max_target_length: int, + max_prefill_predict_length: int, + mesh: Mesh, + rngs: nnx.Rngs, + attention_kernel: str, + dtype: DType = jnp.float32, + weight_dtype: DType = jnp.float32, + dropout_rate: float = 0.0, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), + float32_qk_product: bool = False, # computes logits in float32 for stability. 
+ float32_logits: bool = True, # cast logits in float32 for stability. + fused_qkv: bool = True, + quant: Optional[Quant] = None, + kv_quant: Optional[KVQuant] = None, + use_bias: bool = True, + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), + **kwargs: Any, + ): + self.config = config + self.num_heads = num_heads + self.head_dim = head_dim + self.max_target_length = max_target_length + self.max_prefill_predict_length = max_prefill_predict_length + self.mesh = mesh + self.attention_kernel = attention_kernel + self.dtype = dtype + self.weight_dtype = weight_dtype + self.dropout_rate = dropout_rate + self.kernel_init = kernel_init + self.float32_qk_product = float32_qk_product + self.float32_logits = float32_logits + self.fused_qkv = fused_qkv + self.quant = quant + self.kv_quant = kv_quant + self.use_bias = use_bias + self.input_axis_names = input_axis_names + self.query_axis_names = query_axis_names + self.key_axis_names = key_axis_names + self.value_axis_names = value_axis_names + self.out_axis_names = out_axis_names + self.rngs = rngs + if self.fused_qkv: + self.qkv_proj = self.create_projection_layer( + feature_dim, (3, self.num_heads, self.head_dim), ("embed", "qkv", "heads", "kv") + ) + else: + self.query = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv")) + self.key = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv")) + self.value = self.create_projection_layer(feature_dim, (self.num_heads, self.head_dim), ("embed", "heads", "kv")) + self.out = self.create_projection_layer( + (self.num_heads, self.head_dim), self.num_heads * self.head_dim, ("heads", "kv", "embed"), axis=(-2, -1) + ) + self.attention_op = AttentionOp( + config=config, + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + kv_quant=self.kv_quant, + num_query_heads=self.num_heads, + num_kv_heads=self.num_heads, + dtype=self.dtype, + ) - qkv_proj = dense_general( - inputs_shape=inputs.shape, - out_features_shape=(3, self.num_heads, self.head_dim), - axis=-1, + def create_projection_layer( + self, + input_shape: tuple[int, ...], + output_shape: tuple[int, ...] | int, + kernel_axes: tuple[str, ...], + axis: int | tuple[int, ...] = -1, + ): + """Create projection layer for Key, Value, Query and Output""" + axis = canonicalize_tuple(axis) + in_features_shape = tuple(input_shape[ax] for ax in normalize_axes(axis, len(input_shape))) + + return DenseGeneral( + in_features_shape=in_features_shape, + out_features_shape=output_shape, + axis=axis, kernel_init=self.kernel_init, - kernel_axes=("embed", "qkv", "heads", "kv"), + kernel_axes=kernel_axes, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=proj_name, quant=self.quant, use_bias=self.use_bias, matmul_precision=self.config.matmul_precision, - )(inputs) + rngs=self.rngs, + ) + + def qkv_projection(self, projection_layer: Any, inputs: Array): + """Fused QKV projection""" + qkv_proj = projection_layer(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] 
return query, key, value - def projection(self, inputs: Array, proj_name: str) -> Array: + def projection(self, projection_layer: Any, inputs: Array) -> Array: """individual projection for one of q, k and v.""" - proj = dense_general( - inputs_shape=inputs.shape, - out_features_shape=(self.num_heads, self.head_dim), - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=("embed", "heads", "kv"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - matmul_precision=self.config.matmul_precision, - )(inputs) + proj = projection_layer(inputs) return proj - def out_projection(self, output_dim: int, out: Array) -> Array: - """output projection""" - out_proj = dense_general( - inputs_shape=out.shape, - out_features_shape=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=("heads", "kv", "embed"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name="out", - quant=self.quant, - use_bias=self.use_bias, - matmul_precision=self.config.matmul_precision, - )(out) - return out_proj - - @nn.compact def __call__( self, inputs_q: Array, decoder_segment_ids: Array | None = None, *, - model_mode: str = MODEL_MODE_TRAIN, deterministic: bool = False, + model_mode: str = MODEL_MODE_TRAIN, kv_cache: Array | None = None, attention_metadata: dict[str, Any] | None = None, ): inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names) if self.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") + query, key, value = self.qkv_projection(self.qkv_proj, inputs_q) else: - query = self.projection(inputs_q, proj_name="query") - key = self.projection(inputs_q, proj_name="key") - value = self.projection(inputs_q, proj_name="value") + query = self.projection(self.query, inputs_q) + key = self.projection(self.key, inputs_q) + value = self.projection(self.value, inputs_q) depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) query /= depth_scaling @@ -293,26 +329,12 @@ def __call__( value = nn.with_logical_constraint(value, self.value_axis_names) value = checkpoint_name(value, "value_proj") - attention_op = attention_op_as_linen( - config=self.config, - mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - kv_quant=self.kv_quant, - num_query_heads=self.num_heads, - num_kv_heads=self.num_heads, - dtype=self.dtype, - ) - - out = attention_op(query, key, value, decoder_segment_ids, model_mode) + out = self.attention_op(query, key, value, decoder_segment_ids, model_mode) out = nn.with_logical_constraint(out, self.out_axis_names) # apply output projection, output dim is set to the input dim. 
- out = self.out_projection(inputs_q.shape[-1], out) + out = self.projection(self.out, out) out = checkpoint_name(out, "out_proj") return out, kv_cache @@ -322,15 +344,79 @@ def __call__( # ----------------------------------------- -class Gpt3DecoderLayer(nn.Module): +class Gpt3DecoderLayer(nnx.Module): """Transformer decoder layer that attends to the encoder.""" - config: models.Config - mesh: Mesh - model_mode: str - quant: None | Quant = None + def __init__( + self, + config: models.Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[Quant] = None, + ): + + self.config = config + self.mesh = mesh + self.quant = quant + self.rngs = rngs + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_norm = Gpt3LayerNorm( + num_features=dummy_inputs_shape[-1], + dtype=config.dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + reductions_in_fp32=False, + use_bias=True, + rngs=self.rngs, + ) + + self.mlp = MlpBlock( + mesh=self.mesh, + in_features=dummy_inputs_shape[-1], + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + use_bias=True, + use_pre_norm=True, + config=config, + quant=self.quant, + model_mode=model_mode, + rngs=self.rngs, + ) + + self.self_attention = Gpt3MultiHeadAttention( + config=config, + num_heads=config.num_query_heads, + dtype=config.dtype, + feature_dim=dummy_inputs_shape, + weight_dtype=config.weight_dtype, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + mesh=self.mesh, + dropout_rate=config.dropout_rate, + name="self_attention", + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + fused_qkv=config.fused_qkv, + use_bias=True, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=self.rngs, + ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") - @nn.compact def __call__( self, inputs, @@ -344,48 +430,18 @@ def __call__( kv_cache=None, attention_metadata=None, ): - cfg = self.config - mesh = self.mesh - - inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = gpt3_layer_norm( - num_features=inputs.shape[-1], - dtype=cfg.dtype, - name="pre_self_attention_norm", - kernel_axes=("norm",), - epsilon=cfg.normalization_layer_epsilon, - reductions_in_fp32=False, - use_bias=True, - )(inputs) + lnx = self.pre_self_attention_norm(inputs) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) + lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) # Self-attention block assert ( - cfg.num_query_heads == cfg.num_kv_heads - ), f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" - attention_layer = Gpt3MultiHeadAttention( - config=cfg, - num_heads=cfg.num_query_heads, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - head_dim=cfg.head_dim, - 
max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dropout_rate=cfg.dropout_rate, - name="self_attention", - float32_qk_product=cfg.float32_qk_product, - float32_logits=cfg.float32_logits, - fused_qkv=cfg.fused_qkv, - use_bias=True, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg), - ) + self.config.num_query_heads == self.config.num_kv_heads + ), f"{self.config.num_query_heads=} should be the same as {self.config.num_kv_heads=} in gpt3" - attention_lnx, kv_cache = attention_layer( + attention_lnx, kv_cache = self.self_attention( lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, @@ -394,38 +450,17 @@ def __call__( attention_metadata=attention_metadata, ) - attention_lnx = nn.with_logical_constraint( - attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") - ) + attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) attention_lnx += inputs - # MLP block. - mlp_lnx = mlp_block( - in_features=attention_lnx.shape[-1], - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="mlp", - use_bias=True, - use_pre_norm=True, - config=cfg, - quant=self.quant, - mesh=self.mesh, - )(attention_lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) + mlp_lnx = self.mlp(attention_lnx, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) layer_output = attention_lnx + mlp_lnx + layer_output = self.dropout(layer_output, deterministic=deterministic) + layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) - - layer_output = nn.with_logical_constraint( - layer_output, - ("activation_batch", "activation_norm_length", "activation_embed"), - ) - - if cfg.record_internal_nn_metrics: + if self.config.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( @@ -434,7 +469,13 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - if cfg.scan_layers: + if self.config.scan_layers: return layer_output, None else: return layer_output, kv_cache + + +Gpt3DecoderLayerToLinen = nnx_wrappers.to_linen_class( + Gpt3DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..523f3b171 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,105 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytest configuration helpers for decoupled test selection. 
+ +Automatically apply the `decoupled` marker (when DECOUPLE_GCLOUD=TRUE) to +tests that remain collected. Tests that are explicitly skipped because they +require external integrations or specific hardware (for example `tpu_only`) +are not marked. +""" + +import pytest +from MaxText.gcloud_stub import is_decoupled +import jax + +# Configure JAX to use unsafe_rbg PRNG implementation to match main scripts. +if is_decoupled(): + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + +try: + _HAS_TPU = any(d.platform == "tpu" for d in jax.devices()) +except Exception: # pragma: no cover pylint: disable=broad-exception-caught + _HAS_TPU = False + +try: + _HAS_GPU = any(d.platform == "gpu" for d in jax.devices()) +except Exception: # pragma: no cover pylint: disable=broad-exception-caught + _HAS_GPU = False + + +GCP_MARKERS = {"external_serving", "external_training"} + + +def pytest_collection_modifyitems(config, items): + """Customize pytest collection behavior. + + - Skip hardware-specific tests when hardware is missing. + - Deselect tests marked as external_serving/training in decoupled mode. + - Mark remaining tests with the `decoupled` marker when running decoupled. + """ + decoupled = is_decoupled() + remaining = [] + deselected = [] + + skip_no_tpu = None + skip_no_gpu = None + if not _HAS_TPU: + skip_no_tpu = pytest.mark.skip(reason="Skipped: requires TPU hardware, none detected") + + if not _HAS_GPU: + skip_no_gpu = pytest.mark.skip(reason="Skipped: requires GPU hardware, none detected") + + for item in items: + # Iterate thru the markers of every test. + cur_test_markers = {m.name for m in item.iter_markers()} + + # Hardware skip retains skip semantics. + if skip_no_tpu and "tpu_only" in cur_test_markers: + item.add_marker(skip_no_tpu) + remaining.append(item) + continue + + if skip_no_gpu and "gpu_only" in cur_test_markers: + item.add_marker(skip_no_gpu) + remaining.append(item) + continue + + if decoupled and (cur_test_markers & GCP_MARKERS): + # Deselect tests marked as external_serving/training entirely. + deselected.append(item) + continue + + remaining.append(item) + + # Update items in-place to only keep remaining tests. + items[:] = remaining + if deselected: + config.hook.pytest_deselected(items=deselected) + + # Add decoupled marker to all remaining tests when running decoupled. 
+ if decoupled: + for item in remaining: + item.add_marker(pytest.mark.decoupled) + + +def pytest_configure(config): + for m in [ + "gpu_only: tests that require GPU hardware", + "tpu_only: tests that require TPU hardware", + "external_serving: JetStream / serving / decode server components", + "external_training: goodput integrations", + "decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE", + ]: + config.addinivalue_line("markers", m) diff --git a/tests/grain_data_processing_test.py b/tests/grain_data_processing_test.py index b81c038cf..1ae759851 100644 --- a/tests/grain_data_processing_test.py +++ b/tests/grain_data_processing_test.py @@ -22,6 +22,7 @@ import json import jax +import pytest from jax.sharding import Mesh from jax.experimental import mesh_utils @@ -182,6 +183,49 @@ def setUp(self): self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) +class GrainArrayRecordAutoTuneTest(GrainArrayRecordProcessingTest): + """Test grain data processing with auto-tuning enabled (grain_worker_count=-1).""" + + def setUp(self): + super().setUp() + temp_dir = tempfile.gettempdir() + self.config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_type="grain", + grain_train_files=os.path.join( + temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*" + ), + grain_worker_count=-1, # Enable auto-tuning + tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"), + enable_checkpointing=False, + ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + + @pytest.mark.skip( + reason=( + "Auto-tuning tries multiple numbers of workers during the first few batches " + "and it affects batch determinism at first." + ) + ) + def test_batch_determinism(self): + super().test_batch_determinism() + + class GrainParquetProcessingTest(unittest.TestCase): @classmethod diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..560269758 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test utilities file for helper for test configuration path selection. + +Provides a single helper to return the absolute path to a test config. 
When +running in decoupled mode (DECOUPLE_GCLOUD=TRUE) the decoupled test config is +returned. +""" + +import os +from MaxText.gcloud_stub import is_decoupled +from MaxText.globals import MAXTEXT_PKG_DIR + + +def get_test_config_path(): + """Return absolute path to the chosen test config file. + + Returns `decoupled_base_test.yml` when decoupled, otherwise `base.yml`. + """ + base_cfg = "base.yml" + if is_decoupled(): + base_cfg = "decoupled_base_test.yml" + return os.path.join(MAXTEXT_PKG_DIR, "configs", base_cfg) + + +__all__ = ["get_test_config_path"]
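+
+# Illustrative usage in a test (a sketch: the run_name value is made up, the import path
+# assumes tests are collected as the `tests` package, and the pyconfig call mirrors the
+# pattern used in grain_data_processing_test.py above):
+#
+#   import sys
+#   from MaxText import pyconfig
+#   from tests.test_utils import get_test_config_path
+#
+#   config = pyconfig.initialize(
+#       [sys.argv[0], get_test_config_path()],
+#       run_name="decoupled_config_test",
+#       enable_checkpointing=False,
+#   )
+#
+# With DECOUPLE_GCLOUD=TRUE this resolves to decoupled_base_test.yml, so the test writes
+# to a local base_output_directory instead of a gs:// bucket.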