Skip to content

Commit c253e7f

Browse files
authored
Modifications to allow for using ModernBERT as a base encoder. (Machine-Learning-for-Medical-Language#226)
1 parent 84e784d commit c253e7f

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

src/cnlpt/CnlpModelForClassification.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ def __init__(
240240
if encoder_name.startswith("distilbert"):
241241
self.hidden_dropout_prob = self.encoder_config["dropout"]
242242
self.hidden_size = self.encoder_config["dim"]
243+
elif self.encoder_config["model_type"] == "modernbert":
244+
self.hidden_size = self.encoder_config["hidden_size"]
245+
# downstream uses hidden dropout prob for additional layers, modernbert splits into different dropouts for different
246+
# parts of the encoder -- mlp dropout is probably generally good
247+
self.hidden_dropout_prob = self.encoder_config["mlp_dropout"]
248+
# don't need these in my code but keep them around just in case
249+
self.attention_dropout = self.encoder_config["attention_dropout"]
250+
self.embedding_dropout = self.encoder_config["embedding_dropout"]
251+
self.mlp_dropout = self.encoder_config["mlp_dropout"]
252+
self.classifier_dropout = self.encoder_config["classifier_dropout"]
243253
else:
244254
try:
245255
self.hidden_dropout_prob = self.encoder_config["hidden_dropout_prob"]
@@ -285,6 +295,7 @@ def __init__(
285295
config.encoder_config = encoder_config.to_dict()
286296
encoder_model = AutoModel.from_config(encoder_config)
287297
self.encoder = encoder_model.from_pretrained(config.encoder_name)
298+
288299
# part of the motivation for leaving this
289300
# logic alone for character level models is that
290301
# at the time of writing, CANINE and Flair are the only game in town.
@@ -299,10 +310,10 @@ def __init__(
299310
# It also will be used as the canonical order of returning results/logits
300311
self.tasks = config.finetuning_task
301312

302-
if config.layer > len(encoder_model.encoder.layer):
313+
if config.layer > self.num_layers:
303314
raise ValueError(
304315
"The layer specified (%d) is too big for the specified encoder which has %d layers"
305-
% (config.layer, len(encoder_model.encoder.layer))
316+
% (config.layer, self.num_layers)
306317
)
307318

308319
if freeze > 0:
@@ -358,6 +369,13 @@ def __init__(
358369

359370
# self.init_weights()
360371

372+
@property
def num_layers(self):
    """Return the number of transformer layers in the wrapped encoder.

    ModernBERT keeps its layer stack at ``base_model.layers``, whereas
    classic BERT-style encoders expose it as ``encoder.layer``, so the
    lookup is dispatched on the encoder config's ``model_type``.
    """
    enc = self.encoder
    if enc.config.model_type == "modernbert":
        return len(enc.base_model.layers)
    return len(enc.encoder.layer)
378+
361379
def predict_relations_with_previous_logits(
362380
self, features: torch.Tensor, logits: torch.Tensor
363381
) -> torch.Tensor:

0 commit comments

Comments
 (0)