Skip to content

Commit 9741312

Browse files
committed
add fix for remove columns collator which fails with streaming
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 2dbff48 commit 9741312

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from fms_acceleration import AccelerationPlugin
2121
from peft import LoraConfig
2222
from transformers import DataCollatorForSeq2Seq, TrainingArguments
23+
from transformers.trainer_utils import RemoveColumnsCollator
2324
from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
2425
import torch
2526

@@ -70,8 +71,14 @@ def _collator_check(collate_fn):
7071
# "The padding-free plugin currently only works with a
7172
# `DataCollatorForSeq2Seq` collate_fn,
7273
# otherwise the collation can be unreliable"
74+
if isinstance(collate_fn, RemoveColumnsCollator):
75+
collate_fn = collate_fn.data_collator
7376
return isinstance(
74-
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
77+
collate_fn,
78+
(
79+
DataCollatorForSeq2Seq,
80+
DataCollatorForCompletionOnlyLM,
81+
),
7582
)
7683

7784
# This check is done here to only patch the attention forward
@@ -97,6 +104,14 @@ def _collator_check(collate_fn):
97104

98105
def _collator_replacement_builder(collate_fn):
99106

107+
# in case of remove columns collator the actual collate
108+
# function is wrapped inside
109+
if isinstance(collate_fn, RemoveColumnsCollator):
110+
actual_collate_fn = collate_fn.data_collator
111+
replacement = _collator_replacement_builder(actual_collate_fn)
112+
collate_fn.data_collator = replacement
113+
return collate_fn
114+
100115
# in this case, replace seq2seq with flattening collator
101116
if isinstance(collate_fn, DataCollatorForSeq2Seq):
102117
return DataCollatorWithFlattening()

0 commit comments

Comments
 (0)