Skip to content

Commit 83b6ac4

Browse files
Merge remote-tracking branch 'upstream/main'
2 parents 38cf6ab + 336803f commit 83b6ac4

21 files changed

+265
-103
lines changed

examples/configs/distillation_math.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ policy: &POLICY_BASE
155155
use_custom_fsdp: false
156156
data_parallel_sharding_strategy: "optim_grads_params"
157157

158+
fp8_cfg:
159+
enabled: false
160+
fp8: "e4m3"
161+
fp8_recipe: "blockwise"
162+
fp8_param: false
163+
158164
scheduler:
159165
- name: "torch.optim.lr_scheduler.LinearLR"
160166
kwargs:

examples/configs/distillation_math_megatron.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ policy: &POLICY_BASE
106106
use_custom_fsdp: false
107107
data_parallel_sharding_strategy: "optim_grads_params"
108108

109+
fp8_cfg:
110+
enabled: false
111+
fp8: "e4m3"
112+
fp8_recipe: "blockwise"
113+
fp8_param: false
114+
109115
generation:
110116
backend: "vllm"
111117
max_new_tokens: ${..max_total_sequence_length} # refer to local policy/teacher config

examples/configs/dpo.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@ policy:
177177
overlap_param_gather: true
178178
data_parallel_sharding_strategy: "optim_grads_params"
179179
use_custom_fsdp: false
180-
180+
181+
fp8_cfg:
182+
enabled: false
183+
fp8: "e4m3"
184+
fp8_recipe: "blockwise"
185+
fp8_param: false
186+
181187
data:
182188
max_input_seq_length: ${policy.max_total_sequence_length}
183189
shuffle: true

examples/configs/grpo_math_1B.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,11 @@ policy:
189189
use_custom_fsdp: false
190190
data_parallel_sharding_strategy: "optim_grads_params"
191191

192-
fp8_cfg: null
192+
fp8_cfg:
193+
enabled: false
194+
fp8: "e4m3"
195+
fp8_recipe: "blockwise"
196+
fp8_param: false
193197

194198
env_vars: null
195199

examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ policy:
3333
lr_warmup_init: 5.0e-08
3434
fp8_cfg:
3535
enabled: true
36-
fp8: e4m3
37-
fp8_recipe: blockwise
38-
fp8_param: false
3936
env_vars:
4037
NVTE_FP8_BLOCK_SCALING_FP32_SCALES: '1'
4138
generation:

examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ policy:
2828
apply_rope_fusion: false
2929
fp8_cfg:
3030
enabled: true
31-
fp8: e4m3
32-
fp8_recipe: blockwise
33-
fp8_param: false
3431
optimizer:
3532
lr: 1.0e-06
3633
use_precision_aware_optimizer: false
@@ -43,10 +40,9 @@ policy:
4340
precision: fp8
4441
use_deep_gemm: true
4542
gpu_memory_utilization: 0.5
46-
quantization_ignored_layer_kws: [
47-
a_proj,
48-
b_proj
49-
]
43+
quantization_ignored_layer_kws:
44+
- a_proj
45+
- b_proj
5046
logger:
5147
monitor_gpus: false
5248
wandb:

examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g-fp8-async-1off.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ policy:
55
megatron_cfg:
66
fp8_cfg:
77
enabled: true
8-
fp8: "e4m3"
9-
fp8_recipe: "blockwise"
10-
fp8_param: false
118
env_vars:
129
NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1"
1310
generation:

examples/configs/rm.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ policy:
128128
overlap_param_gather: false
129129
data_parallel_sharding_strategy: "optim_grads_params"
130130

131+
fp8_cfg:
132+
enabled: false
133+
fp8: "e4m3"
134+
fp8_recipe: "blockwise"
135+
fp8_param: false
131136

132137
data:
133138
max_input_seq_length: ${policy.max_total_sequence_length}

examples/configs/sft.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ policy:
175175
data_parallel_sharding_strategy: "optim_grads_params"
176176
use_custom_fsdp: false
177177

178+
fp8_cfg:
179+
enabled: false
180+
fp8: "e4m3"
181+
fp8_recipe: "blockwise"
182+
fp8_param: false
183+
178184
data:
179185
max_input_seq_length: ${policy.max_total_sequence_length}
180186
add_bos: true

examples/configs/sft_openmathinstruct2_megatron.yaml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,11 @@ policy:
100100
env_vars:
101101
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
102102

103-
## fp8 training currently not supported
104-
#fp8_cfg:
105-
# enabled: true
106-
# fp8: hybrid
107-
# fp8_recipe: delayed
108-
# fp8_param: true # false gives the following error: "RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:116 in function CanonicalizeGemmInput: Assertion failed: !is_fp8_dtype(ret.Atype). Input A is missing column-wise usage"
109-
# fp8_dot_product_attention: false #true
110-
# fp8_multi_head_attention: false #true
103+
fp8_cfg:
104+
enabled: false
105+
fp8: "e4m3"
106+
fp8_recipe: "blockwise"
107+
fp8_param: false
111108

112109
dynamic_batching:
113110
enabled: false

0 commit comments

Comments (0)