Skip to content

Commit 801f63e

Browse files
authored
fix: surface trust_remote_code (#1139)
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 5b25e73 commit 801f63e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

nemo_automodel/components/models/biencoder/llama_bidirectional_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ def build(
528528
pooling: str = "avg",
529529
l2_normalize: bool = True,
530530
t: float = 1.0,
531+
trust_remote_code: bool = False,
531532
**hf_kwargs,
532533
):
533534
"""
@@ -544,6 +545,7 @@ def build(
544545
pooling: Pooling strategy ('avg', 'cls', 'last', etc.)
545546
l2_normalize: Whether to L2 normalize embeddings
546547
t: Temperature for scaling similarity scores
548+
trust_remote_code: Whether to trust remote code
547549
**hf_kwargs: Additional arguments passed to model loading
548550
"""
549551

@@ -575,7 +577,7 @@ def build(
575577
# Load model locally or from hub using selected model class
576578
if os.path.isdir(model_name_or_path):
577579
if share_encoder:
578-
lm_q = ModelClass.from_pretrained(model_name_or_path, trust_remote_code=True, **hf_kwargs)
580+
lm_q = ModelClass.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs)
579581
lm_p = lm_q
580582
else:
581583
_qry_model_path = os.path.join(model_name_or_path, "query_model")
@@ -585,8 +587,8 @@ def build(
585587
_qry_model_path = model_name_or_path
586588
_psg_model_path = model_name_or_path
587589

588-
lm_q = ModelClass.from_pretrained(_qry_model_path, trust_remote_code=True, **hf_kwargs)
589-
lm_p = ModelClass.from_pretrained(_psg_model_path, trust_remote_code=True, **hf_kwargs)
590+
lm_q = ModelClass.from_pretrained(_qry_model_path, trust_remote_code=trust_remote_code, **hf_kwargs)
591+
lm_p = ModelClass.from_pretrained(_psg_model_path, trust_remote_code=trust_remote_code, **hf_kwargs)
590592
else:
591593
# Load from hub
592594
lm_q = ModelClass.from_pretrained(model_name_or_path, **hf_kwargs)

0 commit comments

Comments
 (0)