Releases: google/tunix
Tunix v0.1.6 — Agentic RL & VLM
Highlights
- Supports agentic RL training; see https://github.com/google/tunix/tree/main/examples/agentic/gemma_grpo_demo_nb.py
- Supports VLM training; see https://github.com/google/tunix/blob/main/examples/sft/vlm_training.py
```python
from tunix import AgenticGRPOConfig
from tunix import AgenticGRPOLearner

# The upper-case constants (NUM_GENERATIONS, BETA, SWE_SYSTEM_PROMPT, ...) and
# rl_cluster, reward_fns, MyAgentClass, MyEnv, chat_parser and train_dataset
# are user-defined and must be set up before this snippet runs.
agentic_grpo_config = AgenticGRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    max_response_length=MAX_RESPONSE_LENGTH,
    beta=BETA,
    epsilon=EPSILON,
    system_prompt=SWE_SYSTEM_PROMPT,
    max_concurrency=1,
    epsilon_high=0.28,
    off_policy_steps=0,
)

agentic_grpo_learner = AgenticGRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=reward_fns,
    agent_class=MyAgentClass,
    agent_kwargs={},
    env_class=MyEnv,
    env_kwargs={"max_steps": MAX_STEPS},
    algo_config=agentic_grpo_config,
    chat_parser=chat_parser,
)

agentic_grpo_learner.train(train_dataset=train_dataset)
```

What's Changed
- Developing for v0.1.6 now. by @wang2yn84 in #785
- Fix the vllm server mode not finish issue. by @wang2yn84 in #784
- [Tunix] Update Dockerfile and deepscaler trainer script to separate trainer model and ref model. by @copybara-service[bot] in #725
- Add Tunix RL GRPO examples for Gemma3. by @copybara-service[bot] in #788
- [Tunix] change model implementation to be pytree compatible. by @copybara-service[bot] in #782
- Fix TPU nightly regression workflow to use vLLM container and add new tests. by @copybara-service[bot] in #754
- [Tunix] Update sharding configuration for attention weights. by @copybara-service[bot] in #759
- [Tunix] Add gcsfs to TPU nightly regression dependencies. by @copybara-service[bot] in #790
- Adding back test_logprobs_extraction_with_missing_token. by @wang2yn84 in #789
- feat: add device indexes for sglang jax by @pathfinder-pf in #786
- Fix the rendering issue in Example gallery document. by @rajasekharporeddy in #799
- [Tunix] Remove the version pin for SGLang. by @copybara-service[bot] in #798
- [Fixes 794] fix transformers=4.57.1 to solve issue42369 in transformers and use c… by @aolemila in #795
- Refactor gemma3 modelConfig to explicitly include all models by @copybara-service[bot] in #792
- [Tunix] Fix nightly regression: remove unnecessary --root-dir argument from TPU nightly regression script. Fix the MATH500 eval script. by @copybara-service[bot] in #796
- use naming utils in tunix cli by @copybara-service[bot] in #736
- [Tunix] Remove GitHub Actions replacement in copybara. Relying on more generic google3 replacement rule by @copybara-service[bot] in #803
- reduce safetensor loading time by @keshavb96 in #760
- [Tunix] Remove env_utils.fs_open from safetensors_loader. fsspec object doesn't have fileno. 3P test is broken: https://github.com/google/tunix/actions/runs/19689186862/job/56403241781?pr=744 by @copybara-service[bot] in #804
- [Tunix] Pass HF_TOKEN to TPU nightly regression tests. by @copybara-service[bot] in #805
- [Tunix] Follow up of cl/836961494. It was out of sync with github PR. by @copybara-service[bot] in #807
- [Tunix] Pin the vLLM TPU Docker image to a specific nightly build version for the TPU tests. by @copybara-service[bot] in #808
- [Tunix] Update tunix nightly regression workflow schedule. Change the cron schedule from 2 AM UTC to 10 AM UTC. by @copybara-service[bot] in #806
- Centralize Flax sharding setup in env_utils by @copybara-service[bot] in #797
- Fix gemma3 grpo shell scripts by @sizhit2 in #791
- [Tunix] Fix GRPO script. by @copybara-service[bot] in #811
- rename all model configs to use "p" instead of "_" for float values by @copybara-service[bot] in #740
- [Tunix] Move model alignment tests from CPU to TPU run dev workflow. by @copybara-service[bot] in #818
- handle the situation when lora_config is not provided by @Hanjun-Dai in #813
- checkpoint_options->checkpointing_options in cli/config.py by @Hanjun-Dai in #814
- [Tunix] Remove EOS token appending to the prompt in vLLM and SGLang sampler. by @copybara-service[bot] in #827
- Fix bos duplication by @Hanjun-Dai in #822
- remove extra flax sharding check by @copybara-service[bot] in #817
- Remove irrelevant text in GRPO example by @copybara-service[bot] in #823
- [TUNIX] Switch to absl.logging in the tunix util file for scripts. by @copybara-service[bot] in #831
- renaming Transformer to Gemma for gemma model by @copybara-service[bot] in #819
- Expand model tests and fix gemma from_params parsing by @copybara-service[bot] in #828
- add missing refactoring to model test by @copybara-service[bot] in #835
- allow user to config project name and run name in wandb by @Hanjun-Dai in #836
- fix the issue when eager mode jax is triggered in undesired places by @Hanjun-Dai in #837
- make TFDS download flag a configurable option by @copybara-service[bot] in #763
- Fix llama RL verl script by @copybara-service[bot] in #839
- Fix ref model compute_logps input sharding issue by @copybara-service[bot] in #846
- Improves the GRPO script to be more configurable. by @wang2yn84 in #840
- remove unused fn by @copybara-service[bot] in #847
- Add support for Dr. GRPO by @copybara-service[bot] in #681
- [Tunix] Update parallel sizes to use ROLLOUT_MESH in grpo_demo. by @copybara-service[bot] in #851
- Fix typo and citation formatting by @selamw1 in #865
- [Bug] Fix/sglang jax support pathways by @aolemila in #860
- [Tunix] Add number of batches argument and reduce nightly regression run time. by @copybara-service[bot] in #866
- Add AgenticRLLearner base class. by @copybara-service[bot] in #829
- Add XM launch for tunix cli by @copybara-service[bot] in #848
- update OSS readme by @copybara-service[bot] in #863
- use env_utils in config_test by @copybara-service[bot] in #872
- check integer type by @copybara-service[bot] in #877
- Add smoke shell scripts to nightly run by @copybara-service[bot] in #855
- use np instead of jnp to compute rewards in agentic framework by @copybara-service[bot] in #881
- [Tunix] Support pre-resharding pytrees with different meshes. by @copybara-service[bot] in #882
- Fix the ValueError while loading the Gemma model in logit_distillation.ipynb by @rajasekharporeddy in #870
- change qwen3_30b to the more specific qwen3_30b_a3b by @copybara-service[bot] in #880
- add qwen4b model config which uses tied embeddings by @Hanjun-Dai in #858
- Add codewiki link by @copybara-service[bot] in #886
- [Script] merge grpo_demo_sglang_jax_rollout.py into grpo_demo_llama3_qwen2.py by @aolemila in #868
- Adding support for gemma-X-, llama-X- naming similar to HF by @copybara-service[bot] in #876
- enforce rollout tokens to be in RAM by @copybara-service[bot] in #889
- Adding Automodel interface to Tunix by @copybara-service[bot] in #862
- allow users to import reward module/fn outside tunix folder by @Hanjun-Dai in #852
- Fix breaking config test by @copybara-service[bot] in #901
- use np instead of jnp for reward fn and GRPO group adv by @copybara-service[bot] in #891
- change qwen3_4b_2507 model config added to match HF model ids by @copybara-service[bot] in #897
- [Tunix] Remove EOS token appending to the prompt in vLLM and SGLang sampler. by @copybara-service[bot] in #900
- use model_path instead of model_id for gcs in the cli by @copybara-service[bot] in #904
- change from mock...
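Several changes above concern GRPO's advantage computation (Dr. GRPO support in #681, NumPy-based rewards and group advantages in #881 and #891), and the agentic config example sets epsilon_high=0.28. The following is a minimal NumPy sketch of those ideas, not Tunix's implementation: group_advantages and clipped_surrogate are hypothetical helpers, and we assume epsilon_high acts as a separate, larger upper clip bound in the style of DAPO's "clip-higher". Dr. GRPO also changes length normalization in the loss, which this sketch does not model.

```python
import numpy as np


def group_advantages(rewards: np.ndarray, num_generations: int,
                     normalize_std: bool = True, eps: float = 1e-6) -> np.ndarray:
    """Group-relative advantages for a flat batch of rewards (hypothetical helper).

    rewards holds num_generations consecutive completions per prompt.
    normalize_std=True divides by the per-group std (original GRPO recipe);
    normalize_std=False only subtracts the group mean (Dr. GRPO style).
    """
    groups = rewards.reshape(-1, num_generations)
    adv = groups - groups.mean(axis=1, keepdims=True)
    if normalize_std:
        adv = adv / (groups.std(axis=1, keepdims=True) + eps)
    return adv.reshape(-1)


def clipped_surrogate(ratio: np.ndarray, adv: np.ndarray,
                      epsilon: float = 0.2, epsilon_high=None) -> np.ndarray:
    """Per-token PPO-style clipped objective, to be maximized (hypothetical helper).

    ratio is pi_new / pi_old per token. If epsilon_high is set, it replaces
    epsilon as the upper clip bound only; the lower bound stays 1 - epsilon.
    """
    upper = epsilon if epsilon_high is None else epsilon_high
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + upper)
    return np.minimum(ratio * adv, clipped * adv)
```

With normalize_std=False the advantages keep the reward scale; in both variants a group whose rewards are all equal contributes zero advantage.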
Tunix v0.1.5 — Critical Issue Fix for v0.1.4
API Change
This release fixes a critical issue introduced in v0.1.4 that prevented correct functionality.
Users are strongly recommended to upgrade to v0.1.5.
```python
# old:
rl_trainer = GrpoLearner(
    grpo_config=grpo_config,
)

# new:
rl_trainer = GrpoLearner(
    algo_config=grpo_config,
)
```
What's Changed
- Remove grpo helper. by @copybara-service[bot] in #771
- Fix the GitHub source links in example notebooks on dpo, grpo and qlora by @rajasekharporeddy in #775
- adding support for cns file downloads in tunix cli by @copybara-service[bot] in #762
- Developing on v0.1.5 now by @wang2yn84 in #776
- Replace grpo_config with algo_config while calling GRPOLearner in GRPO Demo notebook by @rajasekharporeddy in #778
- Lazy load transformers by @copybara-service[bot] in #779
- Fix first_micro_batch_rollout_time by @copybara-service[bot] in #783
Full Changelog: v0.1.4...v0.1.5
Tunix v0.1.4 — JAX 0.8.1 Flax 0.12.1
Highlights
- Following the JAX 0.8.1 release, Flax 0.12.1 was released, so the Qwix version limit has been removed.
- Tunix supports data parallelism (DP) on the vLLM backend.
- Enables performance tracer: https://github.com/google/tunix/tree/main/tunix/perf
API Changes
```python
# Old:
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        # ...
    },
    training_config=rl_cluster_lib.RLTrainingConfig(
        # ...
    ),
    rollout_engine=args.rollout_engine,
    rollout_config=base_rollout.RolloutConfig(
        # ...
    ),
    rollout_vllm_model_version=VLLM_MODEL_VERSION,
    # ...
)
```

```python
# New: rollout_vllm_model_version moves into RolloutConfig.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        # ...
    },
    training_config=rl_cluster_lib.RLTrainingConfig(
        # ...
    ),
    rollout_engine=args.rollout_engine,
    rollout_config=base_rollout.RolloutConfig(
        # ...
        rollout_vllm_model_version=VLLM_MODEL_VERSION,
        # ...
    ),
)
```
New Features
Model Support:
- Added configuration for the Qwen2.5 math-1.5b model.
- Included mobile fine-tuning examples for Gemma 270M.
SGLang Integration:
- Introduced an SGLang JAX sampler.
- Added SGLang JAX mapping for Qwen2 models.
- Enabled SGLang/JAX CI.
Agentic Workflows:
- Added ModelAgent and TaskEnvironment for single-turn agentic workflows.
- Introduced an Agentic GRPOLearner for RL training.
- Provided a script for GRPO agent mode.
- Added tests for agentic_grpo_learner.
- Implemented Agentic GRPO with multi-iteration support and fixes.
Training & Evaluation:
- Added support for ORPO trainer.
- Included scripts for OSS MATH500 evaluation and DeepScaleR.
Infrastructure:
- Added Dockerfile and build scripts for Tunix for GKE development.
- Implemented GitHub Actions workflows for Tunix TPU nightly regression.
- Added support for plugin-style custom logging backends in MetricsLogger.
Improvements
Model Loading & Configuration:
- Refactored model loading from Flax Orbax checkpoints, including fixes for Gemma and Gemma2.
- Refactored gemma modelConfig to explicitly include all models.
- Relaxed frozen configuration for models.
Performance & Efficiency:
- Improved speed of safetensor loading.
- Added per-Python-thread timeline and export of perf metrics to metrics_logger.
- Rewrote the performance tracer with a new data model.
- Enabled vLLM Data Parallelism on Tunix.
Architecture & Refactoring:
- Moved agentic code out of the experimental folder.
- Moved rollout related configs from cluster config to rollout_config.
- Updated trajectory engine code.
- Updated RolloutOrchestrator logic.
- Implemented a concrete naming structure for parsing HuggingFace model IDs.
- Updated model module to prevent AttributeError with pytree=false.
Usability:
- Updated vanilla sampler to accept single strings.
- Made put_exception in GRPO agentic learner asynchronous.
- Enabled micro_batch_size for rollout and reference models in the PPO learner.
- Added support for user-defined rollout engines.
- Added Kaggle and GitHub buttons to Tunix example notebooks.
- Improved HBM usage reporting in multi-process SPMD.
Internal:
- Refactored TPU tests to run separately based on HF_TOKEN requirements.
- Updated Tunix GitHub Actions to trigger on push to main.
- Moved Docker files to the root directory.
- Added backward compatibility for set_mesh.
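The Architecture & Refactoring list above mentions a naming structure for parsing Hugging Face model IDs, and v0.1.6 renames model configs to write decimal points as "p" instead of "_" (#740). The sketch below shows one hypothetical way such a mapping could work; hf_id_to_config_name and its exact rules are assumptions for illustration, not Tunix's parser. Writing "2.5" as "2p5" keeps the resulting name a valid Python identifier.

```python
import re


def hf_id_to_config_name(model_id: str) -> str:
    """Hypothetical mapping from a Hugging Face model ID to a config name.

    Drops the organization prefix, lowercases, writes a decimal point between
    digits as "p" (2.5 -> 2p5), and turns "-" into "_".
    """
    name = model_id.split("/")[-1].lower()
    name = re.sub(r"(\d)\.(\d)", r"\1p\2", name)
    return name.replace("-", "_")
```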
Bug Fixes
- Fixed broken CI due to vLLM.
- Fixed vLLM driver tests.
- Improved test collection to only include target tests.
- Fixed a conditional issue in the Tunix Gemma implementation.
- Fixed nnx.remat usage with bound methods.
- Fixed the OSS GRPO training script.
- Fixed Qwen2 mapping for SGLang/JAX.
- Fixed an incorrect loss type issue.
- Fixed max_step initialization when profiling.
- Fixed issues with multiple metrics loggers.
- Reduced test flakiness.
- Fixed broken links in README.md.
- Corrected algo_config naming in GRPOLearner.
- Fixed the get_logprobs_from_vllm_output utility function.
- Fixed TypeError in example notebooks by updating mesh indexing (MESH[0] to len(MESH[0])).
- Addressed a very weird bug. (Details pending)
Documentation
- Fixed documentation build for ReadTheDocs.
- Minor fix on grpo_demo description.
- Added README for SGLang JAX.
- Updated docstring usage for dataclasses.
Internal & Tooling
- Automated GitHub issue assignment to all engineers.
- Converted notebook files (.ipynb) to Python scripts (_nb.py) and removed Jupyter cell markers.
- Updated debug logging.
- Pinned Qwix version to 0.1.1 (and later removed the pin).
- Ensured latest dependencies are installed by forcing reinstall.
- Temporarily disabled SGLang tests.
- Removed gcsfs from pyproject.toml dependencies.
Detailed PRs
- Fix the gemma2 loading from flax orbax checkpoint. [1/N] Refactor model loading by @wang2yn84 in #595
- Fix gemma model. [2/N] Refactor model loading by @wang2yn84 in #596
- Fix broken CI due to vllm. by @wang2yn84 in #600
- Fix vllm driver test. by @wang2yn84 in #601
- Fix the tests collection so that it only collects the target tests. by @wang2yn84 in #604
- Adds async tests for vllm server mode. by @wang2yn84 in #603
- This is trying to fix a very weird bug... by @copybara-service[bot] in #605
- [Tunix] Add qwen2.5 math-1.5b model config by @copybara-service[bot] in #598
- add sglang jax sampler by @pathfinder-pf in #553
- add qwen2 sglang jax mapping by @pathfinder-pf in #608
- Move agentic out from experimental folder by @copybara-service[bot] in #597
- Fix conditional in Tunix Gemma implementation by @copybara-service[bot] in #609
- vanilla sampler accept single str by @copybara-service[bot] in #613
- Auto assign github issues to all the eng. by @wang2yn84 in #615
- test: add model alignment test for qwen2 models. by @copybara-service[bot] in #607
- add sglang jax readme by @pathfinder-pf in #617
- relax frozen config for model by @copybara-service[bot] in #619
- support user defined rollout engine by @copybara-service[bot] in #618
- The grpo_trainer.train call was updated to use the train_dataset variable. by @copybara-service[bot] in #626
- Fix nnx.remat usage with bound methods. by @copybara-service[bot] in #621
- Move rollout related configs from cluster config to rollout_config. by @wang2yn84 in #630
- Log individual trajectory rewards, rather than average across microbatch by @copybara-service[bot] in #571
- speed up safetensor loading by @copybara-service[bot] in #629
- Make put_exception as async by @copybara-service[bot] in #635
- Add ModelAgent and TaskEnvironment for single-turn agentic workflows. by @copybara-service[bot] in #634
- style: unify abc annotation usage. by @copybara-service[bot] in #620
- test: add model alignment test for qwen3 models. by @copybara-service[bot] in #632
- Adds [GemmaChatTemplateParser] to support message formatting for Gemma models, including handling of its unique system prompt and token structure. by @copybara-service[bot] in #639
- Trajectory engine code update by @copybara-service[bot] in #645
- Not donating model for jitted eval step. by @copybara-service[bot] in #648
- Update RolloutOrchestrator by @copybara-service[bot] in #640
- Add Agentic GRPOLearner for RL training. by @copybara-service[bot] in #646
- [JAX] Replace reference to jax._src.lib.xla_client.SingleDeviceSharding with jax.sharding.SingleDeviceSharding, which is its public name. by @copybara-service[bot] in #653
- Remove explicit sharding after applying LoRA. by @copybara-service[bot] in #636
- Fix max_step when initializing a profiler by @copybara-service[bot] in #651
- Use correct docstring for dataclasses by @copybara-service[bot] in #661
- OSS math500 eval and deepscalar script. by @copybara-service[bot] in #658
- Fix the OSS GRPO training script. by @wang2yn84 in #660
- fix: qwen2 mapping_sglang_jax by @jimoosciuc in #665
- Fix the wrong loss type issue. by @copybara-service[bot] in #669
- [Tunix] Enable micro_batch_size for rollout and reference models in PPO learner. by @copybara-service[bot] in #668
- Convert notebook file to Python script. by @copybara-service[bot] in #671
- [Tunix] Update the debug logging. by @copybara-service[bot] in #674
- [Tunix] Minor fix on grpo_demo description by @copybara-service[bot] in #677
- Add GitHub Actions work...
Tunix v0.1.3 — JAX 0.8 and new Qwen / Llama3 model support
A maintenance and feature release focused on TPU readiness, test hardening, and model additions. Highlights include a JAX upgrade, SFT/CI improvements, new Qwen and Llama3 model variants, and multiple bugfixes across training and distillation tooling.
Highlights
- Bumped JAX to 0.8.0 for improved compatibility and performance. JAX 0.7.2 has a compilation performance regression, so we are skipping that version.
- Added vLLM TPU to dev mode.
- Qwen2.5 (including 1.5B) and Llama3 (70B & 405B) support added.
What's Changed
- Bump up Tunix to v0.1.3 for dev by @wang2yn84 in #551
- more unittest by @copybara-service[bot] in #550
- Move CLI utils test to CPU test. by @copybara-service[bot] in #532
- Clean up vllm tests. by @wang2yn84 in #556
- fix qwen2.5 model by @copybara-service[bot] in #558
- add qwen2.5 1.5b by @copybara-service[bot] in #559
- make shell scripts executable by @sizhit2 in #545
- Refactor the weight mapping config by @wang2yn84 in #562
- [Tunix] Minor change to remove unnecessary type casting by @copybara-service[bot] in #565
- Make sft smoke test executable and runnable in tpu workflow. by @copybara-service[bot] in #552
- Fix broken distillation notebook by @copybara-service[bot] in #563
- Modify DPO loss function by @copybara-service[bot] in #564
- Async rollout code update by @copybara-service[bot] in #566
- Exporting the CheckpointManager class by @copybara-service[bot] in #572
- Fixes copybara service; the replace rule doesn't work by @copybara-service[bot] in #575
- Fix PeftTrainer and DPO bugs by @copybara-service[bot] in #580
- add build test for /models. by @copybara-service[bot] in #577
- Add test import check for all build target under /rl, /utils, /tests folder. by @copybara-service[bot] in #576
- Bump up Jax version to 0.8.0 by @wang2yn84 in #581
- Fix metric logging for DPO by @copybara-service[bot] in #583
- add llama3 70 & 405b by @copybara-service[bot] in #589
Full Changelog: v0.1.2...v0.1.3
Tunix v0.1.2: Expanded Model Support and Enhanced Flexibility
This release of Tunix introduces support for new models, enhances core functionalities for more flexible and efficient workflows, and includes several important fixes.
Highlights
- Expanded Model Support: We've added a configuration for qwen-8b and ported the Llama3 example to the Tunix CLI. Additionally, disaggregated GRPO on llama3.1-70b is now supported through MaxText, including checkpoint saving.
- Enhanced Flexibility: Users can now specify a different data type for the rollout model and take advantage of more flexible PyTree support in the checkpoint manager. This release also introduces flexible collect modes and tokenization support, along with support for multiple EOS tokens in the vanilla sampler.
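The vanilla sampler now accepts multiple EOS tokens. As an illustration of the idea only (not the sampler's code), the NumPy sketch below cuts each generated sequence at the first occurrence of any id in a set of stop tokens. truncate_at_eos is a hypothetical helper, and the token ids in the usage are arbitrary.

```python
import numpy as np


def truncate_at_eos(tokens: np.ndarray, eos_ids, pad_id: int = 0):
    """Cut each row at the first occurrence of any id in eos_ids (hypothetical helper).

    tokens: [batch, seq_len] integer array of generated token ids.
    Returns (padded_tokens, lengths); each length includes the EOS token,
    and rows with no EOS keep their full length.
    """
    seq_len = tokens.shape[1]
    is_eos = np.isin(tokens, np.asarray(eos_ids))
    # Position of each EOS token, or seq_len where the token is not an EOS.
    positions = np.where(is_eos, np.arange(seq_len), seq_len)
    first_eos = positions.min(axis=1)
    lengths = np.minimum(first_eos + 1, seq_len)
    # Keep tokens up to and including the first EOS; pad the rest.
    keep = np.arange(seq_len)[None, :] < lengths[:, None]
    return np.where(keep, tokens, pad_id), lengths


out, lengths = truncate_at_eos(
    np.array([[5, 7, 1, 9], [4, 106, 8, 1], [3, 3, 3, 3]]), eos_ids=[1, 106])
```

Treating any member of eos_ids as a stop token lets a chat model end on either its end-of-turn token or the tokenizer's EOS token.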
Other Changes
- Downgraded JAX to 0.7.1 in prod mode due to a performance regression; dev mode still supports JAX 0.7.2.
- Fixes to the front page pip install command and GRPO examples.
- Improvements to the checkpoint manager and resharding library.
- Added a backward compatibility test for Orbax checkpoint restoration.
- Various code simplifications, refactoring, and documentation updates.
Full Changelog: v0.1.1...v0.1.2
What's Changed
- [Tunix] Allow specifying a different data type for the rollout model. by @copybara-service[bot] in #513
- Fix the front page pip install command. by @wang2yn84 in #515
- Remove prompt_template.py by @copybara-service[bot] in #514
- Tool code update by @copybara-service[bot] in #471
- simplify micro batching config by @copybara-service[bot] in #516
- Add explicit imports for specific TFDS datasets. by @copybara-service[bot] in #519
- Adding a qwen-8b config by @copybara-service[bot] in #522
- Environment code update by @copybara-service[bot] in #512
- [Tunix] Update the checkpoint manager with more flexible PyTree. by @copybara-service[bot] in #337
- Add BibTeX by @copybara-service[bot] in #528
- Ensure GitHub workflow failures block presubmit by @copybara-service[bot] in #487
- [Tunix] fix the grpo example which blocks copybara presubmit by @copybara-service[bot] in #531
- [Tunix] Add backward compatibility test for Orbax checkpoint restoration. by @copybara-service[bot] in #530
- [Tunix] Update reshard lib to respect logical axis rules. by @copybara-service[bot] in #518
- support multiple eos tokens in vanilla sampler by @copybara-service[bot] in #525
- Add sft smoke test. by @copybara-service[bot] in #533
- Port Precur.AI llama3 example to Tunix CLI by @copybara-service[bot] in #520
- Make GRPO disaggregated llama3.1 70b work with pathways including ckpt saving by @A9isha in #527
- Rename 'convert_messages_to_tokens_and_masks' to 'tokenize_and_generate_masks' by @copybara-service[bot] in #539
- Change docker version to jax0.7.1_rev1 by @copybara-service[bot] in #544
- Add flexible collect modes and tokenization support. by @copybara-service[bot] in #526
- fix mtnt import by @copybara-service[bot] in #546
- update function docstring for tokenize_and_generate_masks by @copybara-service[bot] in #547
- Move grpo shell scripts. by @copybara-service[bot] in #543
- minor simplification by @copybara-service[bot] in #541
Tunix v0.1.1 — Improved Stability, New Features, and TPU Optimizations
This release focuses on improving performance and stability across TPU and Kaggle environments, introducing new utilities for agentic RL workflows, and adding broader model and configuration support. It also includes several important bug fixes and developer experience improvements.
Run Tunix on Kaggle TPU
We’re excited to announce that Tunix can now be launched directly in Kaggle notebooks with TPU acceleration — making it easier than ever to experiment, prototype, and run reinforcement learning workflows without complex setup.
Key highlights
First-class TPU support on Kaggle – run GRPO and other RL pipelines end-to-end in a Kaggle notebook.
Pre-configured runtime – no manual dependency juggling needed; version compatibility and performance tuning are handled automatically.
Launch the notebook here:
Knowledge Distillation Demo
QLoRA Demo
DPO Demo
GRPO Demo
New Features & Improvements
- Model & Training Options
  - Added support for Gemma-3-270M model configuration.
  - Enabled setting default parameter dtype for Gemma-3 models.
  - Added remat options to models to improve memory efficiency.
  - Created a new list container type to support both Flax ≤0.11.2 and ≥0.12.0 versions.
- Pathways & TPU Performance
  - Introduced experimental pre-sharding (experimental_reshard) for Pathways on Cloud TPU.
  - Improved weight synchronization logic to handle KV head duplication.
  - Disabled certain profiler options by default to improve stability on Pathways backend.
- Configuration & CLI Improvements
  - Enabled generic creation of optax.optimizer and optax.learning_rate_schedule directly from CLI.
  - Relaxed JAX version constraints to ensure compatibility with Kaggle images.
  - Added minimum resource requirements for launch scripts in the README.
- Documentation
  - Added ReadTheDocs link in README.
  - Expanded external notebooks with step-by-step guidance for long-running tasks.
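The weight-sync bullet above mentions KV head duplication. One common case, which we assume is the one meant here, is serving a grouped-query-attention model with a tensor-parallel size larger than its number of KV heads: each KV head is then replicated so every shard holds one. A minimal NumPy sketch, assuming a [num_kv_heads, head_dim, hidden] weight layout; duplicate_kv_heads is hypothetical and is not Tunix's weight-sync code.

```python
import numpy as np


def duplicate_kv_heads(kv_weight: np.ndarray, tp_size: int) -> np.ndarray:
    """Replicate KV heads so they split evenly over tp_size shards (hypothetical helper).

    kv_weight: [num_kv_heads, head_dim, hidden] projection weight (assumed layout).
    Returns the weight unchanged when tp_size <= num_kv_heads.
    """
    num_kv_heads = kv_weight.shape[0]
    if tp_size <= num_kv_heads:
        return kv_weight
    if tp_size % num_kv_heads:
        raise ValueError(
            f"tp_size={tp_size} is not a multiple of num_kv_heads={num_kv_heads}")
    # Repeat each head consecutively, e.g. [h0, h0, h1, h1] for a factor of 2.
    return np.repeat(kv_weight, tp_size // num_kv_heads, axis=0)
```

Repeating heads consecutively means contiguous shard i receives KV head i // (tp_size // num_kv_heads), which lines up with contiguous sharding of the query heads that use that KV head.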
Bug Fixes
- Fixed a bug in reward function logic causing incorrect training signals.
- Fixed a checkpoint handling issue where Colab failed to locate the final checkpoint; intermediate checkpoint directories are now cleaned up.
- Fixed Kaggle image performance issues.
- Fixed type errors in agents/ modules.
- Optimized masked index lookups using jnp.where for better runtime efficiency.
- Resharded prompt and completion tokens to the REFERENCE mesh when rollout and reference models are distributed.
Dependency & Version Updates
- JAX pinned to 0.7.1 and libtpu downgraded to resolve Cloud TPU performance regressions.
- Relaxed JAX version requirement for Kaggle compatibility.
Full Changelog:
- Bump up the version to v0.1.0 by @wang2yn84 in #446
- Delete this notebook as it's redundant now. Prepare for the release. by @copybara-service[bot] in #445
- Add min resources requirements for launch scripts to README by @copybara-service[bot] in #424
- Enable generic creation of optax.optimizer, optax.learning_rate_schedule from cli by @copybara-service[bot] in #435
- [Tunix] Reshard prompt and completion tokens to the REFERENCE mesh before computing reference log probabilities if needed. This is needed when rollout and reference are distributed. by @copybara-service[bot] in #451
- Add Typed... types to ArrayLike. by @copybara-service[bot] in #461
- Downgrade the Jax/libtpu version to resolve performance issue on Cloud TPU by @wang2yn84 in #465
- Pin Jax version to 0.7.1 by @wang2yn84 in #468
- add a comment for version pinning by @copybara-service[bot] in #469
- update internal grpo notebook by @copybara-service[bot] in #463
- Adds experimental pre-shard to Pathways on Cloud experimental_reshard by @copybara-service[bot] in #473
- Relax the jax version requirement to get a working Kaggle image. by @wang2yn84 in #474
- add remat options to model by @copybara-service[bot] in #470
- Create a new list container type to support both flax<=0.11.2 and >=0.12.0. by @copybara-service[bot] in #476
- Enable setting default param dtype for Gemma 3 model by @copybara-service[bot] in #482
- new reward functions and unit tests by @copybara-service[bot] in #472
- add readthedoc in readme by @copybara-service[bot] in #485
- Refactor: Optimize masked index lookup using jnp.where by @copybara-service[bot] in #490
- Fix the slow kaggle image issue. by @copybara-service[bot] in #488
- Fix reward function bug by @copybara-service[bot] in #486
- Fix the colab that can't find the final checkpoint. Cleanup the intermediate checkpoint directory. by @wang2yn84 in #495
- Fix type errors within agents/ by @copybara-service[bot] in #496
- Update weight sync logic to handle KV head duplication by @wenxindongwork in #464
- Added grpo test to tpu-tunix tests by @mydatascience in #447
- Add Gemma3-270M model configuration support #500 by @chethanuk in #501
- Disable setting specific profiler options on Pathways backend. by @copybara-service[bot] in #494
- Adds description to external notebook for steps that take long. by @wang2yn84 in #481
New Contributors
- @chethanuk made their first contribution in #501
Full Changelog: v0.1.0...v0.1.1
Tunix v0.1.0 — First Public Release of Google’s Reinforcement Learning Library for LLM Post-Training
We’re thrilled to announce Tunix v0.1.0, the first public release of Google’s lightweight, JAX-native library for post-training large language models (LLMs) using both reinforcement learning (RL) and supervised fine-tuning (SFT). Tunix is built for researchers and production teams who want maximum control and scalability when aligning and improving foundation models — from data loading to distributed rollout and training on TPUs.
Highlights of v0.1.0
SFT (Supervised Fine-Tuning): Seamlessly train your LLMs with labeled datasets to bootstrap alignment before RL or as a standalone approach.
High-efficiency reinforcement learning (RL) and preference-optimization algorithms such as GRPO, GSPO, PPO, and DPO, designed for instruction tuning and reward-based LLM alignment.
End-to-End RL Pipeline: From reward function definition to rollout and policy optimization, everything is fully integrated and composable.
Multi-Model Support: Works out of the box with leading open-weight models, including Gemma 2/3, LLaMA 3, and Qwen 2/3 — and can be extended to other Hugging Face models with minimal effort.
Seamless TPU / CPU Execution: Tunix is built on top of JAX and Flax with first-class support for multi-device and multi-host environments.
Dataset Flexibility: Use TensorFlow Datasets, Kaggle datasets, or custom Grain datasets with minimal changes.
Modular Design: Clean abstractions for samplers, reward functions, trainers, and optimizers — making it easy to extend or plug into your own workflows.
Get Started
Install Tunix from PyPI:
pip install google-tunix[prod]
We recommend starting with the GRPO demo notebook to see how reinforcement learning can be applied to real LLM training.
Tunix 0.1.0.dev1 – Development Preview
This is the first development release of Tunix, Google’s reinforcement learning library for language model post-training.
Note: This is a pre-release (.dev1) version meant for testing and feedback.
APIs and behavior may change before the official 0.1.0 stable release.
Use this build to validate early integrations, experiment with new features, and provide feedback.
Install this dev release:
pip install --pre google-tunix[prod]==0.1.0.dev1
Tunix 0.1.0.dev0 – Development Preview
This is the first development release of Tunix, Google’s reinforcement learning library for language model post-training.
Note: This is a pre-release (.dev0) version meant for testing and feedback.
APIs and behavior may change before the official 0.1.0 stable release.
Use this build to validate early integrations, experiment with new features, and provide feedback.
Install this dev release:
pip install --pre google-tunix==0.1.0.dev0