Skip to content

Commit da17c7c

Browse files
authored
fix: use dp_world_size instead of world_size for batch_size with tensor parallelism (#3462) [skip ci]
1 parent cada93c commit da17c7c

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

src/axolotl/utils/config/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def normalize_config(cfg):
119119
if cfg.world_size != 1:
120120
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
121121
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
122-
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
122+
effective_world_size = (
123+
cfg.world_size
124+
// (cfg.context_parallel_size or 1)
125+
// (cfg.tensor_parallel_size or 1)
126+
)
123127
cfg.batch_size = cfg.batch_size * effective_world_size
124128

125129
if not cfg.use_ray:

src/axolotl/utils/trainer.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
457457
- 1
458458
)
459459
* cfg.num_epochs
460-
* cfg.tensor_parallel_size
461460
)
462461
LOG.debug(
463462
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -496,9 +495,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
496495
LOG.debug(f"data_loader_len: {data_loader_len}")
497496
# FIXME: is there a bug here somewhere? the total num steps depends
498497
# on the agreed on value for sample_packing_eff_est
499-
total_num_steps = int(
500-
math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
501-
)
498+
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
502499
if cfg.dataloader_drop_last:
503500
# drop the last batch for each epoch
504501
total_num_steps -= int(math.ceil(cfg.num_epochs))
@@ -519,12 +516,7 @@ def calc_sample_packing_eff_est(estimates: List[float]):
519516
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
520517
else:
521518
total_num_steps = int(
522-
math.ceil(
523-
len(train_dataset)
524-
* cfg.num_epochs
525-
* cfg.tensor_parallel_size
526-
/ cfg.batch_size
527-
)
519+
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
528520
)
529521
LOG.debug(f"total_num_steps: {total_num_steps}")
530522
return total_num_steps
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Tests for batch_size calculation with tensor parallelism."""
2+
3+
from unittest.mock import patch
4+
5+
import addict
6+
import pytest
7+
from axolotl.utils.config import normalize_config, validate_config
8+
from axolotl.utils.dict import DictDefault
9+
10+
11+
@pytest.fixture(name="tp_base_cfg")
12+
def fixture_tp_base_cfg(min_base_cfg):
13+
return (
14+
DictDefault(
15+
micro_batch_size=2,
16+
gradient_accumulation_steps=4,
17+
sequence_len=2048,
18+
num_epochs=1,
19+
)
20+
| min_base_cfg
21+
)
22+
23+
24+
class TestTensorParallelBatchSize:
25+
"""Verify batch_size scales by effective dp world_size when using tensor parallelism."""
26+
27+
@pytest.mark.parametrize(
28+
"world_size, tensor_parallel_size, expected_batch_size",
29+
[
30+
(4, 1, 32), # no TP: 2*4*4 = 32
31+
(4, 2, 16), # TP=2: 2*4*(4//2) = 16
32+
(4, 4, 8), # TP=4: 2*4*(4//4) = 8
33+
(2, 2, 8), # TP=ws: 2*4*(2//2) = 8 (no scaling)
34+
],
35+
)
36+
def test_batch_size_with_tensor_parallelism(
37+
self,
38+
tp_base_cfg,
39+
monkeypatch,
40+
world_size,
41+
tensor_parallel_size,
42+
expected_batch_size,
43+
):
44+
monkeypatch.setenv("WORLD_SIZE", str(world_size))
45+
tp_base_cfg["tensor_parallel_size"] = tensor_parallel_size
46+
cfg = validate_config(tp_base_cfg)
47+
# Mock load_model_config to avoid downloading the model and to bypass
48+
# the tie_word_embeddings validation that blocks TP > 1.
49+
with patch(
50+
"axolotl.utils.config.load_model_config",
51+
return_value=addict.Dict({"model_type": "llama"}),
52+
):
53+
normalize_config(cfg)
54+
assert cfg.batch_size == expected_batch_size

0 commit comments

Comments (0)