File tree Expand file tree Collapse file tree 2 files changed +32
-5
lines changed
src/axolotl/utils/schemas Expand file tree Collapse file tree 2 files changed +32
-5
lines changed Original file line number Diff line number Diff line change @@ -512,10 +512,17 @@ def check_sample_packing_w_rl(cls, data):
512
512
@model_validator (mode = "before" )
513
513
@classmethod
514
514
def hint_sample_packing_padding (cls , data ):
515
- if data .get ("sample_packing" ) and not data .get ("pad_to_sequence_len" ):
516
- LOG .warning (
517
- "`pad_to_sequence_len: true` is recommended when using sample_packing"
518
- )
515
+ if data .get ("sample_packing" ):
516
+ pad_to_sequence_len = data .get ("pad_to_sequence_len" )
517
+ if pad_to_sequence_len is False :
518
+ LOG .warning (
519
+ "`pad_to_sequence_len: true` is recommended when using sample_packing"
520
+ )
521
+ elif pad_to_sequence_len is None :
522
+ LOG .info (
523
+ "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
524
+ )
525
+ data ["pad_to_sequence_len" ] = True
519
526
return data
520
527
521
528
@model_validator (mode = "before" )
Original file line number Diff line number Diff line change @@ -648,7 +648,7 @@ def test_packing(self, minimal_cfg):
648
648
DictDefault (
649
649
{
650
650
"sample_packing" : True ,
651
- "pad_to_sequence_len" : None ,
651
+ "pad_to_sequence_len" : False ,
652
652
"flash_attention" : True ,
653
653
}
654
654
)
@@ -662,6 +662,26 @@ def test_packing(self, minimal_cfg):
662
662
for record in self ._caplog .records
663
663
)
664
664
665
+ def test_packing_autoset (self , minimal_cfg ):
666
+ cfg = (
667
+ DictDefault (
668
+ {
669
+ "sample_packing" : True ,
670
+ "pad_to_sequence_len" : None ,
671
+ "flash_attention" : True ,
672
+ }
673
+ )
674
+ | minimal_cfg
675
+ )
676
+ with self ._caplog .at_level (logging .INFO ):
677
+ cfg = validate_config (cfg )
678
+ assert any (
679
+ "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
680
+ in record .message
681
+ for record in self ._caplog .records
682
+ )
683
+ assert cfg .pad_to_sequence_len is True
684
+
665
685
def test_merge_lora_no_bf16_fail (self , minimal_cfg ):
666
686
"""
667
687
This is assumed to be run on a CPU machine, so bf16 is not supported.
You can’t perform that action at this time.
0 commit comments