Skip to content

Commit 4c58dcb

Browse files
committed
add electra freeze option
1 parent 86044af commit 4c58dcb

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

chebai/models/electra.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __init__(
224224
config: Optional[Dict[str, Any]] = None,
225225
pretrained_checkpoint: Optional[str] = None,
226226
load_prefix: Optional[str] = None,
227+
freeze_electra: bool = False,
227228
**kwargs: Any,
228229
):
229230
# Remove this property in order to prevent it from being stored as a
@@ -262,6 +263,10 @@ def __init__(
262263
else:
263264
self.electra = ElectraModel(config=self.config)
264265

266+
if freeze_electra:
267+
for param in self.electra.parameters():
268+
param.requires_grad = False
269+
265270
def _process_for_loss(
266271
self,
267272
model_output: Dict[str, Tensor],

0 commit comments

Comments
 (0)