
Commit d80b8d1

Merge dev/john/titan-ptc into release/v25.10
2 parents: 653f0ec + f89348d

File tree

25 files changed: +745 −36 lines


examples/run_local_pretrain.sh

Lines changed: 1 addition & 0 deletions

@@ -143,6 +143,7 @@ docker_podman_proxy run --rm \
     --env TORCHTITAN_PATH \
     --env MAXTEXT_PATH \
     --env BACKEND_PATH \
+    --env REBUILD_PRIMUS_TURBO \
     "${ENV_ARGS[@]}" \
     --ipc=host --network=host \
     --device=/dev/kfd --device=/dev/dri \
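
Passing a bare --env REBUILD_PRIMUS_TURBO (with no =value) makes docker/podman forward the variable from the launching shell into the container, so a rebuild switch set on the host becomes visible to the in-container pretrain script. A minimal sketch of the intended host-side call, assuming the script is run from the repository root with its usual arguments:

    # Hypothetical host-side invocation; the bare --env flag above forwards the
    # exported value into the container environment.
    export REBUILD_PRIMUS_TURBO=1
    bash examples/run_local_pretrain.sh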

examples/run_pretrain.sh

Lines changed: 24 additions & 0 deletions

@@ -277,6 +277,30 @@ export NVTE_CK_USES_BWD_V3=${NVTE_CK_USES_BWD_V3:-0}
 # Note: Disable fp32 atomics if you find any accuracy issue.
 export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=${PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32:-0}

+# Install Primus Turbo from source
+export REBUILD_PRIMUS_TURBO=${REBUILD_PRIMUS_TURBO:-0}
+if [ "$REBUILD_PRIMUS_TURBO" == "1" ]; then
+    LOG_INFO "Rebuilding Primus Turbo from source..."
+    mkdir -p "/workspace/turbo"
+    cd "/workspace/turbo"
+
+    # Clean up the old directory if it exists to avoid git clone conflicts
+    if [ -d "Primus-Turbo" ]; then
+        LOG_INFO "Removing existing Primus-Turbo directory..."
+        rm -rf Primus-Turbo
+    fi
+
+    git clone https://github.com/AMD-AGI/Primus-Turbo.git --recursive
+    cd Primus-Turbo
+    pip3 install -r requirements.txt
+    # Set GPU_ARCHS to compile Turbo for multiple AMD GPU architectures.
+    GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
+    cd "${PRIMUS_PATH}"
+    LOG_INFO "Rebuilding Primus Turbo from source done."
+else
+    LOG_INFO "Skip Primus Turbo rebuild. REBUILD_PRIMUS_TURBO=$REBUILD_PRIMUS_TURBO"
+fi
+
 # nvte debug envs
 export NVTE_DEBUG=0 # 0, 1
 export NVTE_DEBUG_LEVEL=0 # 0, 1, 2
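
Taken together, the new block makes the Primus-Turbo source build opt-in: with REBUILD_PRIMUS_TURBO=1 the script clones AMD-AGI/Primus-Turbo into /workspace/turbo and pip-installs it for the gfx942 and gfx950 architectures before training; with the default of 0 it only logs that the rebuild was skipped. A hedged sketch of both call patterns (the training arguments are placeholders, not part of this diff):

    # Default: use the Primus-Turbo already installed in the image
    bash examples/run_pretrain.sh <training-args>

    # Opt in to a fresh source build before training starts
    REBUILD_PRIMUS_TURBO=1 bash examples/run_pretrain.sh <training-args>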

examples/run_slurm_pretrain.sh

Lines changed: 1 addition & 0 deletions

@@ -57,5 +57,6 @@ srun -N "${NNODES}" \
 export NNODES=\${SLURM_NNODES}
 export NODE_RANK=\${SLURM_PROCID}
 export GPUS_PER_NODE=\${SLURM_GPUS_ON_NODE}
+export REBUILD_PRIMUS_TURBO=\${REBUILD_PRIMUS_TURBO}
 bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE}
 " bash "$@"

examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml

Lines changed: 14 additions & 5 deletions

@@ -13,14 +13,14 @@ modules:
     model: deepseek_v3_16b.yaml
     overrides:
       profiling:
-        enable_profiling: false
+        enable_profiling: true
         save_traces_folder: "profile_trace"
         profile_freq: 10
         enable_memory_snapshot: false
         save_memory_snapshot_folder: "memory_snapshot"

       metrics:
-        log_freq: 10
+        log_freq: 1
         disable_color_printing: false
         enable_tensorboard: false
         save_tb_folder: "tb"
@@ -38,11 +38,12 @@ modules:
         min_lr_factor: 0.1

       training:
+        debug_moe_force_load_balance: true
         local_batch_size: 4
         seq_len: 4096
         max_norm: 1.0 # grad norm clipping
-        steps: 1000
-        dataset: "c4" # supported datasets: c4_test (2K), c4 (177M)
+        steps: 15
+        dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M)

       parallelism:
         data_parallel_replicate_degree: 1
@@ -69,8 +70,16 @@ modules:

       compile:
         enable: true
-        components: ["loss"] # ["model", "loss"]
+        components: ["model", "loss"] # ["model", "loss"]

+      primus_turbo:
+        enable_primus_turbo: true
+        use_turbo_mx_linear: false
+        use_turbo_float8_linear: true
+        enable_attention_float8: false
+        use_turbo_grouped_mm: true
+        use_moe_fp8: false
+
       # quantize:
       #   linear:
       #     float8:

examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml

Lines changed: 17 additions & 7 deletions

@@ -13,14 +13,14 @@ modules:
     model: deepseek_v3_671b.yaml
     overrides:
       profiling:
-        enable_profiling: false
+        enable_profiling: true
         save_traces_folder: "profile_trace"
         profile_freq: 10
         enable_memory_snapshot: false
         save_memory_snapshot_folder: "memory_snapshot"

       metrics:
-        log_freq: 10
+        log_freq: 1
         disable_color_printing: false
         enable_tensorboard: false
         save_tb_folder: "tb"
@@ -38,11 +38,12 @@ modules:
         min_lr_factor: 0.1

       training:
-        local_batch_size: 4
+        debug_moe_force_load_balance: true
+        local_batch_size: 16
         seq_len: 4096
         max_norm: 1.0 # grad norm clipping
-        steps: 1000
-        dataset: "c4" # supported datasets: c4_test (2K), c4 (177M)
+        steps: 15
+        dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M)

       parallelism:
         data_parallel_replicate_degree: 1
@@ -52,7 +53,7 @@ modules:
         enable_async_tensor_parallel: false
         pipeline_parallel_degree: 1
         pipeline_parallel_schedule: "Interleaved1F1B"
-        expert_parallel_degree: 1
+        expert_parallel_degree: 8
         expert_tensor_parallel_degree: 1

       checkpoint:
@@ -69,7 +70,16 @@ modules:

       compile:
         enable: true
-        components: ["loss"] # ["model", "loss"]
+        components: ["model", "loss"] # ["model", "loss"]
+
+      primus_turbo:
+        enable_primus_turbo: true
+        use_turbo_mx_linear: false
+        use_turbo_float8_linear: true
+        enable_attention_float8: false
+        use_classic_attention: true
+        use_turbo_grouped_mm: true
+        use_moe_fp8: false

       # quantize:
       #   linear:

examples/torchtitan/configs/MI355X/deepseek_v3_16b-pretrain.yaml

Lines changed: 14 additions & 5 deletions

@@ -13,14 +13,14 @@ modules:
     model: deepseek_v3_16b.yaml
     overrides:
       profiling:
-        enable_profiling: false
+        enable_profiling: true
         save_traces_folder: "profile_trace"
         profile_freq: 10
         enable_memory_snapshot: false
         save_memory_snapshot_folder: "memory_snapshot"

       metrics:
-        log_freq: 10
+        log_freq: 1
         disable_color_printing: false
         enable_tensorboard: false
         save_tb_folder: "tb"
@@ -38,11 +38,12 @@ modules:
         min_lr_factor: 0.1

       training:
+        debug_moe_force_load_balance: true
         local_batch_size: 4
         seq_len: 4096
         max_norm: 1.0 # grad norm clipping
-        steps: 1000
-        dataset: "c4" # supported datasets: c4_test (2K), c4 (177M)
+        steps: 15
+        dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M)

       parallelism:
         data_parallel_replicate_degree: 1
@@ -69,8 +70,16 @@ modules:

       compile:
         enable: true
-        components: ["loss"] # ["model", "loss"]
+        components: ["model", "loss"] # ["model", "loss"]

+      primus_turbo:
+        enable_primus_turbo: true
+        use_turbo_mx_linear: false
+        use_turbo_float8_linear: true
+        enable_attention_float8: false
+        use_turbo_grouped_mm: true
+        use_moe_fp8: false
+
       # quantize:
       #   linear:
       #     float8:

examples/torchtitan/configs/MI355X/deepseek_v3_671b-pretrain.yaml

Lines changed: 17 additions & 7 deletions

@@ -13,14 +13,14 @@ modules:
     model: deepseek_v3_671b.yaml
     overrides:
       profiling:
-        enable_profiling: false
+        enable_profiling: true
         save_traces_folder: "profile_trace"
         profile_freq: 10
         enable_memory_snapshot: false
         save_memory_snapshot_folder: "memory_snapshot"

       metrics:
-        log_freq: 10
+        log_freq: 1
         disable_color_printing: false
         enable_tensorboard: false
         save_tb_folder: "tb"
@@ -38,11 +38,12 @@ modules:
         min_lr_factor: 0.1

       training:
-        local_batch_size: 4
+        debug_moe_force_load_balance: true
+        local_batch_size: 16
         seq_len: 4096
         max_norm: 1.0 # grad norm clipping
-        steps: 1000
-        dataset: "c4" # supported datasets: c4_test (2K), c4 (177M)
+        steps: 15
+        dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M)

       parallelism:
         data_parallel_replicate_degree: 1
@@ -52,7 +53,7 @@ modules:
         enable_async_tensor_parallel: false
         pipeline_parallel_degree: 1
         pipeline_parallel_schedule: "Interleaved1F1B"
-        expert_parallel_degree: 1
+        expert_parallel_degree: 8
         expert_tensor_parallel_degree: 1

       checkpoint:
@@ -69,7 +70,16 @@ modules:

       compile:
         enable: true
-        components: ["loss"] # ["model", "loss"]
+        components: ["model", "loss"] # ["model", "loss"]
+
+      primus_turbo:
+        enable_primus_turbo: true
+        use_turbo_mx_linear: false
+        use_turbo_float8_linear: true
+        enable_attention_float8: false
+        use_classic_attention: true
+        use_turbo_grouped_mm: true
+        use_moe_fp8: false

       # quantize:
       #   linear:

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 10 additions & 1 deletion

@@ -29,7 +29,16 @@
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
 from megatron.training.global_vars import get_args
-from primus_turbo.pytorch.core.float8 import (
+
+try:
+    from primus_turbo.pytorch.core.float8 import (
+        Float8QuantConfig,
+        ScalingGranularity,
+        ScalingStrategy,
+        check_fp8_support,
+    )
+except ImportError:
+    from primus_turbo.pytorch.core.low_precision import (
     Float8QuantConfig,
     ScalingGranularity,
     ScalingStrategy,
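
The try/except tracks a module rename in Primus-Turbo: these float8 symbols were historically importable from primus_turbo.pytorch.core.float8 and, per this change, live in primus_turbo.pytorch.core.low_precision in newer builds, so the extension now accepts either layout. A minimal sketch of the same fallback pattern factored into a hypothetical helper (the helper name is illustrative; module paths are as assumed above):

    # Hypothetical helper showing the fallback-import pattern used in the diff.
    import importlib


    def import_turbo_float8(*names):
        """Return the requested symbols from whichever Primus-Turbo module exists."""
        for module_name in (
            "primus_turbo.pytorch.core.float8",  # older layout
            "primus_turbo.pytorch.core.low_precision",  # newer layout
        ):
            try:
                module = importlib.import_module(module_name)
                return tuple(getattr(module, name) for name in names)
            except (ImportError, AttributeError):
                continue
        raise ImportError("no compatible Primus-Turbo float8/low_precision module found")


    # Usage mirroring the imports above (requires primus_turbo to be installed):
    # Float8QuantConfig, ScalingGranularity, ScalingStrategy, check_fp8_support = (
    #     import_turbo_float8("Float8QuantConfig", "ScalingGranularity",
    #                         "ScalingStrategy", "check_fp8_support")
    # )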

primus/backends/megatron/core/fp8_utils.py

Lines changed: 28 additions & 7 deletions

@@ -42,14 +42,22 @@
 from megatron.core import parallel_state
 from megatron.core.enums import Fp8Recipe
 from megatron.core.extensions.transformer_engine import TEDelayedScaling
-from primus_turbo.pytorch.core.float8 import ScalingGranularity
+try:
+    from primus_turbo.pytorch.core.float8 import ScalingGranularity
+except ImportError:
+    from primus_turbo.pytorch.core.low_precision import ScalingGranularity
+

 from primus.backends.megatron.core.extensions.primus_turbo import (
     PrimusTurboFloat8QuantConfig,
 )

 def te_fp8_format_mapping(te_format):
-    from primus_turbo.pytorch.core.float8 import Format as TurboFormat
+    try:
+        from primus_turbo.pytorch.core.float8 import Format as TurboFormat
+    except ImportError:
+        from primus_turbo.pytorch.core.low_precision import Format as TurboFormat
+        # noqa: F811
     from transformer_engine.common.recipe import Format as TEFormat

     format_mapping = {
@@ -194,7 +202,10 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
     elif HAVE_TURBO:
         from megatron.core import parallel_state
         from megatron.core.enums import Fp8Recipe
-        from primus_turbo.pytorch.core.float8 import ScalingGranularity
+        try:
+            from primus_turbo.pytorch.core.float8 import ScalingGranularity
+        except ImportError:
+            from primus_turbo.pytorch.core.low_precision import ScalingGranularity

         from primus.backends.megatron.core.extensions.primus_turbo import (
             PrimusTurboFloat8QuantConfig,
@@ -234,10 +245,20 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
         # fp8 training and this layer_no is in fp8
         import primus_turbo

-        if config.fp8 == "e4m3":
-            fp8_format = primus_turbo.pytorch.core.float8.Format.E4M3
-        elif config.fp8 == "hybrid":
-            fp8_format = primus_turbo.pytorch.core.float8.Format.HYBRID
+        # Pick the right Format enum once
+        try:
+            # Older API
+            from primus_turbo.pytorch.core.float8 import Format as FP8Format
+        except ImportError:
+            # Newer API
+            from primus_turbo.pytorch.core.low_precision import Format as FP8Format
+
+        fp8_str = config.fp8.lower()
+
+        if fp8_str == "e4m3":
+            fp8_format = FP8Format.E4M3
+        elif fp8_str == "hybrid":
+            fp8_format = FP8Format.HYBRID
         else:
             raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
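
The rewritten branch resolves the Format enum once through the same fallback import and then compares config.fp8 case-insensitively, so "E4M3"/"e4m3" and "Hybrid"/"hybrid" all select the intended recipe. A small standalone sketch of that selection logic, with a stand-in enum instead of the real Primus-Turbo Format:

    # Hedged sketch of the case-insensitive FP8 format selection; FP8Format here is a
    # stand-in for whichever Primus-Turbo Format enum the fallback import resolves.
    from enum import Enum


    class FP8Format(Enum):
        E4M3 = "e4m3"
        HYBRID = "hybrid"


    def pick_fp8_format(fp8: str) -> FP8Format:
        fp8_str = fp8.lower()
        if fp8_str == "e4m3":
            return FP8Format.E4M3
        elif fp8_str == "hybrid":
            return FP8Format.HYBRID
        raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")


    assert pick_fp8_format("E4M3") is FP8Format.E4M3
    assert pick_fp8_format("Hybrid") is FP8Format.HYBRID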
