Skip to content

Commit d066874

Browse files
authored
fix remove_unused_columns (#4749)
1 parent c8bc461 commit d066874

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

swift/llm/argument/base_args/template_args.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,12 @@ def __post_init__(self):
6464
self.truncation_strategy = 'delete'
6565

6666
def get_template_kwargs(self):
67+
from ..train_args import TrainArguments
6768
truncation_strategy = self.truncation_strategy
6869
if truncation_strategy == 'delete':
6970
truncation_strategy = 'raise'
70-
remove_unused_columns = self.remove_unused_columns
71-
if hasattr(self, 'rlhf_type') and self.rlhf_type == 'grpo':
71+
remove_unused_columns = self.remove_unused_columns # from DataArguments
72+
if not isinstance(self, TrainArguments) or hasattr(self, 'rlhf_type') and self.rlhf_type == 'grpo':
7273
remove_unused_columns = True
7374
return {
7475
'default_system': self.system,

0 commit comments

Comments
 (0)