Commit 7f85833
Update torchtitan and train.py (#21)
* activations on CUDA offloaded
* add save_for_all_ranks config
* update torchtitan
* update train.py
* use build_loss_fn
* add get_nparams_and_flops
* remove unused import
* Fix isort issues
---------
Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>
1 parent c949efe commit 7f85833
File tree
6 files changed
+107
-143
lines changed
- 3rdparty
- flame
- models
- tools
6 files changed
+107
-143
lines changed
Submodule flash-linear-attention updated 157 files
Submodule torchtitan updated 92 files
- .github/workflows/integration_test_8gpu.yaml-3
- .github/workflows/lint.yaml+1-1
- CONTRIBUTING.md+5-3
- README.md+14-13
- docs/checkpoint.md+2-2
- docs/composability.md+1-1
- docs/extension.md+58
- docs/fsdp.md-1
- docs/torchft.md+50
- multinode_trainer.slurm+1-1
- pyproject.toml+1-1
- run_train.sh+2-2
- scripts/estimate/estimation.py+16-13
- scripts/estimate/run_memory_estimation.sh+1-1
- scripts/generate/README.md+1-1
- scripts/generate/run_llama_generate.sh+1-1
- scripts/generate/test_generate.py+1-1
- tests/assets/argparser_example.py+16
- tests/integration_tests.py+77-65
- tests/unit_tests/test_dataset_checkpointing.py+1-1
- tests/unit_tests/test_job_config.py+72-14
- tests/unit_tests/test_model_converter.py+7-6
- tests/unit_tests/test_train_spec.py+8-11
- torchtitan/__init__.py+1
- torchtitan/components/checkpoint.py+13-46
- torchtitan/components/dataloader.py+6-2
- torchtitan/components/loss.py+13-2
- torchtitan/components/lr_scheduler.py+174
- torchtitan/components/metrics.py+14-12
- torchtitan/components/optimizer.py+2-155
- torchtitan/config_manager.py+154-103
- torchtitan/datasets/tokenizer/tiktoken.py+5
- torchtitan/distributed/pipeline.py+8-12
- torchtitan/distributed/utils.py+8-4
- torchtitan/experiments/README.md+20
- torchtitan/experiments/__init__.py+1-4
- torchtitan/experiments/deepseek_v3/USAGE.md+8
- torchtitan/experiments/deepseek_v3/download.py+53-4
- torchtitan/experiments/deepseek_v3/generate.py+249
- torchtitan/experiments/deepseek_v3/inference.sh+15
- torchtitan/experiments/deepseek_v3/model.py+215-216
- torchtitan/experiments/deepseek_v3/model_config.py+27-6
- torchtitan/experiments/deepseek_v3/requirements.txt+5
- torchtitan/experiments/deepseek_v3/run.py+82-30
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py+2-2
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py+114-12
- torchtitan/experiments/kernels/triton_mg_group_gemm/autotuner.py+69
- torchtitan/experiments/kernels/triton_mg_group_gemm/autows_version.py+1.9k
- torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py+621
- torchtitan/experiments/kernels/triton_mg_group_gemm/fast_debug.py+403
- torchtitan/experiments/kernels/triton_mg_group_gemm/grid_stride_kernels_old.py+536
- torchtitan/experiments/kernels/triton_mg_group_gemm/mg_grouped_gemm.py+2.1k
- torchtitan/experiments/kernels/triton_mg_group_gemm/mg_grouped_gemm_benchmark_results.png
- torchtitan/experiments/kernels/triton_mg_group_gemm/profile_forward.py+80
- torchtitan/experiments/kernels/triton_mg_group_gemm/simple_MoE.py+875
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py+299
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py+1.3k
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py+126
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py+239
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py+174
- torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py+80
- torchtitan/experiments/kernels/triton_mg_group_gemm/unit_test_mg.py+80
- torchtitan/experiments/kernels/triton_mg_group_gemm/verify_fp8_forward.py+13
- torchtitan/experiments/multimodal/__init__.py+32
- torchtitan/experiments/multimodal/check_padding_mm.py+109
- torchtitan/experiments/multimodal/mm_collator.py+227
- torchtitan/experiments/multimodal/mm_dataset.py+268
- torchtitan/experiments/multimodal/requirements.txt+1
- torchtitan/experiments/multimodal/tokenizer/tiktoken.py+232
- torchtitan/experiments/multimodal/transform.py+185
- torchtitan/experiments/multimodal/utils.py+437
- torchtitan/experiments/simple_fsdp/README.md+40
- torchtitan/experiments/simple_fsdp/__init__.py+33
- torchtitan/experiments/simple_fsdp/model.py+18
- torchtitan/experiments/simple_fsdp/parallelize_llama.py+98
- torchtitan/experiments/simple_fsdp/simple_fsdp.py+194
- torchtitan/experiments/simple_fsdp/tests/__init__.py+5
- torchtitan/experiments/simple_fsdp/tests/test_numerics.py+128
- torchtitan/models/llama/__init__.py+6-5
- torchtitan/models/llama/model.py+73-3
- torchtitan/models/llama/parallelize_llama.py+17-12
- torchtitan/models/llama/pipeline_llama.py+23-9
- torchtitan/models/llama/train_configs/debug_model.toml+6-6
- torchtitan/models/llama/train_configs/llama3_405b.toml+6-6
- torchtitan/models/llama/train_configs/llama3_70b.toml+5-5
- torchtitan/models/llama/train_configs/llama3_8b.toml+5-5
- torchtitan/models/llama_multimodal/model.py+7-9
- torchtitan/models/norms.py+1-32
- torchtitan/protocols/train_spec.py+21-8
- torchtitan/tools/profiling.py+6-1
- torchtitan/tools/utils.py+11-67
- torchtitan/train.py+422-344
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
466 | 466 | | |
467 | 467 | | |
468 | 468 | | |
469 | | - | |
470 | 469 | | |
471 | | - | |
| 470 | + | |
472 | 471 | | |
473 | | - | |
| 472 | + | |
474 | 473 | | |
475 | | - | |
476 | | - | |
477 | | - | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
478 | 478 | | |
479 | 479 | | |
480 | 480 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
151 | 151 | | |
152 | 152 | | |
153 | 153 | | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
158 | | - | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
159 | 163 | | |
160 | 164 | | |
161 | 165 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
10 | | - | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
11 | 19 | | |
12 | 20 | | |
13 | 21 | | |
| |||
28 | 36 | | |
29 | 37 | | |
30 | 38 | | |
31 | | - | |
| 39 | + | |
32 | 40 | | |
33 | | - | |
| 41 | + | |
0 commit comments