Skip to content

Commit 47e0e71

Browse files
djsaunde and winglian authored
don't sort multipack sampler (axolotl-ai-cloud#2657)
* don't sort multipack sampler
* increased packing efficiency increases loss

---------

Co-authored-by: Wing Lian <[email protected]>
1 parent 0f35871 commit 47e0e71

File tree

3 files changed

+5
-8
lines changed

3 files changed

+5
-8
lines changed

src/axolotl/utils/samplers/multipack.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,11 @@ def pack_group(
7878
Returns:
7979
List of bins, where each bin contains indices of sequences assigned to it
8080
"""
81-
# Get sorting indices and sort lengths in descending order
82-
indices = np.argsort(sequence_lengths)[::-1]
83-
sorted_lengths = sequence_lengths[indices]
84-
8581
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
8682
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
8783

88-
for seq_id, size in enumerate(sorted_lengths):
89-
global_idx = indices[seq_id] + group_offset
84+
for seq_id, size in enumerate(sequence_lengths):
85+
global_idx = seq_id + group_offset
9086

9187
# Try to place sequence in existing bins
9288
add_new_bin = True

tests/e2e/integrations/test_kd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_llama_kd(self, temp_dir, kd_min_cfg):
9090
train(cfg=cfg, dataset_meta=dataset_meta)
9191
assert (Path(temp_dir) / "model.safetensors").exists()
9292
check_tensorboard(
93-
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
93+
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
9494
)
9595

9696
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
121121
train(cfg=cfg, dataset_meta=dataset_meta)
122122
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
123123
check_tensorboard(
124-
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
124+
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
125125
)

tests/test_packed_batch_sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,4 @@ def test_packing(
106106

107107
original_idxs = set(range(len(train_dataset)))
108108
assert original_idxs == set(batch_idxs)
109+
assert len(batch_idxs) == len(set(batch_idxs))

0 commit comments

Comments
 (0)