Skip to content

Commit 11be774

Browse files
Fix support for model_init_kwargs in GKD/GOLD when passed as CLI JSON string (#5266)
1 parent e5ea2c4 commit 11be774

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

trl/experimental/gkd/gkd_config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
 15  15   from dataclasses import dataclass, field
 16  16   from typing import Any
 17  17
 18     - from transformers import TrainingArguments
 19     -
 20  18   from ...trainer.sft_config import SFTConfig
 21  19
 22  20

@@ -52,7 +50,7 @@ class GKDConfig(SFTConfig):
 52  50   teacher-generated output).
 53  51   """
 54  52
 55     - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
     53 + _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
 56  54
 57  55   temperature: float = field(
 58  56   default=0.9,

trl/experimental/gold/gold_config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
 15  15   from dataclasses import dataclass, field
 16  16   from typing import Any
 17  17
 18     - from transformers import TrainingArguments
 19     -
 20  18   from ...trainer.sft_config import SFTConfig
 21  19
 22  20

@@ -94,7 +92,7 @@ class GOLDConfig(SFTConfig):
 94  92   low, but waking the engine adds host–device transfer latency.
 95  93   """
 96  94
 97     - _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
     95 + _VALID_DICT_FIELDS = SFTConfig._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
 98  96
 99  97   # Parameters whose default values are overridden from TrainingArguments
100  98   learning_rate: float = field(

0 commit comments

Comments
 (0)