Tunix v0.1.3 — JAX 0.8 and new Qwen / Llama3 model support
A maintenance and feature release focused on TPU readiness, test hardening, and model additions. Highlights include a JAX upgrade, SFT/CI improvements, new Qwen and Llama3 model variants, and multiple bugfixes across training and distillation tooling.
Highlights
- Bumped JAX to 0.8.0 for improved compatibility and performance. JAX 0.7.2 has a compilation-performance regression, so we are skipping that version.
- Added vLLM TPU to the dev mode.
- Added support for Qwen2.5 (including 1.5B) and Llama3 (70B and 405B).
What's Changed
- Bump up Tunix to v0.1.3 for dev by @wang2yn84 in #551
- More unit tests by @copybara-service[bot] in #550
- Move CLI utils test to CPU test. by @copybara-service[bot] in #532
- Clean up vllm tests. by @wang2yn84 in #556
- fix qwen2.5 model by @copybara-service[bot] in #558
- add qwen2.5 1.5b by @copybara-service[bot] in #559
- make shell scripts executable by @sizhit2 in #545
- Refactor the weight mapping config by @wang2yn84 in #562
- [Tunix] Minor change to remove unnecessary type casting by @copybara-service[bot] in #565
- Make sft smoke test executable and runnable in tpu workflow. by @copybara-service[bot] in #552
- Fix broken distillation notebook by @copybara-service[bot] in #563
- Modify DPO loss function by @copybara-service[bot] in #564
- Async rollout code update by @copybara-service[bot] in #566
- Exporting the CheckpointManager class by @copybara-service[bot] in #572
- Fix Copybara service: the replace rule doesn't work by @copybara-service[bot] in #575
- Fix PeftTrainer and DPO bugs by @copybara-service[bot] in #580
- add build test for /models. by @copybara-service[bot] in #577
- Add test import check for all build targets under the /rl, /utils, and /tests folders by @copybara-service[bot] in #576
- Bump up Jax version to 0.8.0 by @wang2yn84 in #581
- Fix metric logging for DPO by @copybara-service[bot] in #583
- Add Llama3 70B & 405B by @copybara-service[bot] in #589
Full Changelog: v0.1.2...v0.1.3