@@ -20,6 +20,7 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from transformers import DataCollatorForSeq2Seq, TrainingArguments
+from transformers.trainer_utils import RemoveColumnsCollator
 from trl import DataCollatorForCompletionOnlyLM  # pylint: disable=import-error
 import torch
 
@@ -70,8 +71,14 @@ def _collator_check(collate_fn):
             # "The padding-free plugin currently only works with a
             # `DataCollatorForSeq2Seq` collate_fn,
             # otherwise the collation can be unreliable"
+            if isinstance(collate_fn, RemoveColumnsCollator):
+                collate_fn = collate_fn.data_collator
             return isinstance(
-                collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
+                collate_fn,
+                (
+                    DataCollatorForSeq2Seq,
+                    DataCollatorForCompletionOnlyLM,
+                ),
             )
 
         # This check is done here to only patch the attention forward
@@ -97,6 +104,14 @@ def _collator_check(collate_fn):
 
         def _collator_replacement_builder(collate_fn):
 
+            # in the case of a RemoveColumnsCollator, the actual collate
+            # function is wrapped inside it
+            if isinstance(collate_fn, RemoveColumnsCollator):
+                actual_collate_fn = collate_fn.data_collator
+                replacement = _collator_replacement_builder(actual_collate_fn)
+                collate_fn.data_collator = replacement
+                return collate_fn
+
             # in this case, replace seq2seq with flattening collator
             if isinstance(collate_fn, DataCollatorForSeq2Seq):
                 return DataCollatorWithFlattening()
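For reference, here is a minimal standalone sketch of the unwrapping behaviour the first hunk adds to `_collator_check`. `transformers`' `Trainer` can wrap the user's collator in a `RemoveColumnsCollator` (which strips unused dataset columns) and keep the real collator in its `data_collator` attribute, so the type check must look through the wrapper. The `collator_check` helper name is a stand-in, and the `DataCollatorForCompletionOnlyLM` branch is omitted here only to avoid the trl dependency:

```python
from transformers import DataCollatorForSeq2Seq
from transformers.trainer_utils import RemoveColumnsCollator


def collator_check(collate_fn):
    # the trainer may hand over a RemoveColumnsCollator that keeps the
    # real collator in its `data_collator` attribute; unwrap it before
    # checking the type
    if isinstance(collate_fn, RemoveColumnsCollator):
        collate_fn = collate_fn.data_collator
    return isinstance(collate_fn, DataCollatorForSeq2Seq)


# tokenizer=None is enough here: the collator is never called, only
# type-checked
inner = DataCollatorForSeq2Seq(tokenizer=None)
wrapped = RemoveColumnsCollator(
    data_collator=inner,
    signature_columns=["input_ids", "attention_mask", "labels"],
)
assert collator_check(inner)
assert collator_check(wrapped)  # passes only because of the unwrapping
```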
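And a sketch of the recursive replacement in the second hunk: when the collator is wrapped, the builder recurses into `data_collator`, swaps the inner `DataCollatorForSeq2Seq` for a flattening collator, and returns the original wrapper so column removal still runs first. `DataCollatorWithFlattening` is imported from `transformers` here (available in recent versions) purely for the demo, and the pass-through fallback for unknown collators is an illustrative simplification, not the plugin's actual code:

```python
from transformers import DataCollatorForSeq2Seq, DataCollatorWithFlattening
from transformers.trainer_utils import RemoveColumnsCollator


def collator_replacement_builder(collate_fn):
    # recurse through the wrapper so that column removal keeps running
    # before the flattening collator
    if isinstance(collate_fn, RemoveColumnsCollator):
        collate_fn.data_collator = collator_replacement_builder(
            collate_fn.data_collator
        )
        return collate_fn
    if isinstance(collate_fn, DataCollatorForSeq2Seq):
        return DataCollatorWithFlattening()
    return collate_fn  # illustrative fallback: leave unknown collators alone


wrapped = RemoveColumnsCollator(
    data_collator=DataCollatorForSeq2Seq(tokenizer=None),
    signature_columns=["input_ids", "attention_mask", "labels"],
)
replaced = collator_replacement_builder(wrapped)
assert isinstance(replaced, RemoveColumnsCollator)
assert isinstance(replaced.data_collator, DataCollatorWithFlattening)
```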