Skip to content

Commit 8c05e58

Browse files
authored
Add fix for the remove-columns collator (`RemoveColumnsCollator`), which fails with streaming (#135)
Signed-off-by: Dushyant Behl <[email protected]>
1 parent 2990230 commit 8c05e58

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from fms_acceleration import AccelerationPlugin
2121
from peft import LoraConfig
2222
from transformers import DataCollatorForSeq2Seq, TrainingArguments
23+
from transformers.trainer_utils import RemoveColumnsCollator
2324
from trl import ( # pylint: disable=import-error, no-name-in-module
2425
DataCollatorForCompletionOnlyLM,
2526
)
@@ -72,6 +73,8 @@ def _collator_check(collate_fn):
7273
# "The padding-free plugin currently only works with a
7374
# `DataCollatorForSeq2Seq` collate_fn,
7475
# otherwise the collation can be unreliable"
76+
if isinstance(collate_fn, RemoveColumnsCollator):
77+
collate_fn = collate_fn.data_collator
7578
return isinstance(
7679
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
7780
)
@@ -99,6 +102,14 @@ def _collator_check(collate_fn):
99102

100103
def _collator_replacement_builder(collate_fn):
101104

105+
# In the case of a RemoveColumnsCollator, the actual collate function
106+
# is wrapped inside it (as `data_collator`), so replace that inner one.
107+
if isinstance(collate_fn, RemoveColumnsCollator):
108+
actual_collate_fn = collate_fn.data_collator
109+
replacement = _collator_replacement_builder(actual_collate_fn)
110+
collate_fn.data_collator = replacement
111+
return collate_fn
112+
102113
# in this case, replace seq2seq with flattening collator
103114
if isinstance(collate_fn, DataCollatorForSeq2Seq):
104115
return DataCollatorWithFlattening()

0 commit comments

Comments
 (0)