@@ -528,6 +528,7 @@ def build(
528528 pooling : str = "avg" ,
529529 l2_normalize : bool = True ,
530530 t : float = 1.0 ,
531+ trust_remote_code : bool = False ,
531532 ** hf_kwargs ,
532533 ):
533534 """
@@ -544,6 +545,7 @@ def build(
544545 pooling: Pooling strategy ('avg', 'cls', 'last', etc.)
545546 l2_normalize: Whether to L2 normalize embeddings
546547 t: Temperature for scaling similarity scores
548+ trust_remote_code: Whether to trust remote code
547549 **hf_kwargs: Additional arguments passed to model loading
548550 """
549551
@@ -575,7 +577,7 @@ def build(
575577 # Load model locally or from hub using selected model class
576578 if os .path .isdir (model_name_or_path ):
577579 if share_encoder :
578- lm_q = ModelClass .from_pretrained (model_name_or_path , trust_remote_code = True , ** hf_kwargs )
580+ lm_q = ModelClass .from_pretrained (model_name_or_path , trust_remote_code = trust_remote_code , ** hf_kwargs )
579581 lm_p = lm_q
580582 else :
581583 _qry_model_path = os .path .join (model_name_or_path , "query_model" )
@@ -585,8 +587,8 @@ def build(
585587 _qry_model_path = model_name_or_path
586588 _psg_model_path = model_name_or_path
587589
588- lm_q = ModelClass .from_pretrained (_qry_model_path , trust_remote_code = True , ** hf_kwargs )
589- lm_p = ModelClass .from_pretrained (_psg_model_path , trust_remote_code = True , ** hf_kwargs )
590+ lm_q = ModelClass .from_pretrained (_qry_model_path , trust_remote_code = trust_remote_code , ** hf_kwargs )
591+ lm_p = ModelClass .from_pretrained (_psg_model_path , trust_remote_code = trust_remote_code , ** hf_kwargs )
590592 else :
591593 # Load from hub
592594 lm_q = ModelClass .from_pretrained (model_name_or_path , ** hf_kwargs )
0 commit comments