We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 86044af commit 4c58dcbCopy full SHA for 4c58dcb
chebai/models/electra.py
@@ -224,6 +224,7 @@ def __init__(
224
config: Optional[Dict[str, Any]] = None,
225
pretrained_checkpoint: Optional[str] = None,
226
load_prefix: Optional[str] = None,
227
+ freeze_electra: bool = False,
228
**kwargs: Any,
229
):
230
# Remove this property in order to prevent it from being stored as a
@@ -262,6 +263,10 @@ def __init__(
262
263
else:
264
self.electra = ElectraModel(config=self.config)
265
266
+ if freeze_electra:
267
+ for param in self.electra.parameters():
268
+ param.requires_grad = False
269
+
270
def _process_for_loss(
271
self,
272
model_output: Dict[str, Tensor],
0 commit comments