|
20 | 20 | from fms_acceleration import AccelerationPlugin |
21 | 21 | from peft import LoraConfig |
22 | 22 | from transformers import DataCollatorForSeq2Seq, TrainingArguments |
| 23 | +from transformers.trainer_utils import RemoveColumnsCollator |
23 | 24 | from trl import ( # pylint: disable=import-error, no-name-in-module |
24 | 25 | DataCollatorForCompletionOnlyLM, |
25 | 26 | ) |
@@ -72,6 +73,8 @@ def _collator_check(collate_fn): |
72 | 73 | # "The padding-free plugin currently only works with a |
73 | 74 | # `DataCollatorForSeq2Seq` collate_fn, |
74 | 75 | # otherwise the collation can be unreliable" |
| 76 | + if isinstance(collate_fn, RemoveColumnsCollator): |
| 77 | + collate_fn = collate_fn.data_collator |
75 | 78 | return isinstance( |
76 | 79 | collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM) |
77 | 80 | ) |
@@ -99,6 +102,14 @@ def _collator_check(collate_fn): |
99 | 102 |
|
100 | 103 | def _collator_replacement_builder(collate_fn): |
101 | 104 |
|
| 105 | + # a RemoveColumnsCollator wraps the actual collate function in |
| 106 | + # its `data_collator` attribute, so replace the wrapped one in place |
| 107 | + if isinstance(collate_fn, RemoveColumnsCollator): |
| 108 | + actual_collate_fn = collate_fn.data_collator |
| 109 | + replacement = _collator_replacement_builder(actual_collate_fn) |
| 110 | + collate_fn.data_collator = replacement |
| 111 | + return collate_fn |
| 112 | + |
102 | 113 | # in this case, replace seq2seq with flattening collator |
103 | 114 | if isinstance(collate_fn, DataCollatorForSeq2Seq): |
104 | 115 | return DataCollatorWithFlattening() |
|
0 commit comments