Skip to content

Commit e05a187

Browse files
committed
remove unsafe_rbg from post training
1 parent acd9e07 commit e05a187

File tree

3 files changed

+5
-10
lines changed

3 files changed

+5
-10
lines changed

docs/tutorials/grpo.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ recommend creating the virtual environment outside the `maxtext` directory.
3131

3232
## vLLM and tpu-inference installations
3333

34+
### From PyPI releases
35+
3436
Next, run the following bash script to get all the necessary installations inside the virtual environment (for e.g., `maxtext_venv`).
3537
This will take few minutes. Follow along the installation logs and look out for any issues!
3638

@@ -40,6 +42,9 @@ bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
4042

4143
Primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support.
4244

45+
### From Github
46+
47+
You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source)
4348

4449
## Run GRPO
4550

src/MaxText/rl/train_rl.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,12 +460,7 @@ def main(argv: Sequence[str]) -> None:
460460
argv: Command-line arguments.
461461
"""
462462
pathwaysutils.initialize()
463-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
464463
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
465-
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
466-
os.environ["LIBTPU_INIT_ARGS"] = (
467-
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
468-
)
469464

470465
max_utils.print_system_information()
471466
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)

src/MaxText/sft/sft_trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,7 @@ def main(argv: Sequence[str]) -> None:
184184
argv: Command-line arguments.
185185
"""
186186
pathwaysutils.initialize()
187-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
188187
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
189-
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
190-
os.environ["LIBTPU_INIT_ARGS"] = (
191-
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
192-
)
193188

194189
mt_config = pyconfig.initialize(argv)
195190
max_utils.print_system_information()

0 commit comments

Comments
 (0)