Skip to content

Commit bcb59c7

Browse files
authored
Automatically set pad_to_sequence_len when using sample packing (axolotl-ai-cloud#2607)
* automatically set pad_to_sequence_len when using sample packing * update tests
1 parent 6a3e6f8 commit bcb59c7

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

src/axolotl/utils/schemas/config.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -512,10 +512,17 @@ def check_sample_packing_w_rl(cls, data):
512512
@model_validator(mode="before")
513513
@classmethod
514514
def hint_sample_packing_padding(cls, data):
515-
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
516-
LOG.warning(
517-
"`pad_to_sequence_len: true` is recommended when using sample_packing"
518-
)
515+
if data.get("sample_packing"):
516+
pad_to_sequence_len = data.get("pad_to_sequence_len")
517+
if pad_to_sequence_len is False:
518+
LOG.warning(
519+
"`pad_to_sequence_len: true` is recommended when using sample_packing"
520+
)
521+
elif pad_to_sequence_len is None:
522+
LOG.info(
523+
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
524+
)
525+
data["pad_to_sequence_len"] = True
519526
return data
520527

521528
@model_validator(mode="before")

tests/patched/test_validation.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -648,7 +648,7 @@ def test_packing(self, minimal_cfg):
648648
DictDefault(
649649
{
650650
"sample_packing": True,
651-
"pad_to_sequence_len": None,
651+
"pad_to_sequence_len": False,
652652
"flash_attention": True,
653653
}
654654
)
@@ -662,6 +662,26 @@ def test_packing(self, minimal_cfg):
662662
for record in self._caplog.records
663663
)
664664

665+
def test_packing_autoset(self, minimal_cfg):
666+
cfg = (
667+
DictDefault(
668+
{
669+
"sample_packing": True,
670+
"pad_to_sequence_len": None,
671+
"flash_attention": True,
672+
}
673+
)
674+
| minimal_cfg
675+
)
676+
with self._caplog.at_level(logging.INFO):
677+
cfg = validate_config(cfg)
678+
assert any(
679+
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
680+
in record.message
681+
for record in self._caplog.records
682+
)
683+
assert cfg.pad_to_sequence_len is True
684+
665685
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
666686
"""
667687
This is assumed to be run on a CPU machine, so bf16 is not supported.

0 commit comments

Comments
 (0)