@@ -24,7 +24,7 @@ use anyhow::Context;
2424use candle:: { DType , Device } ;
2525use candle_nn:: VarBuilder ;
2626use nohash_hasher:: BuildNoHashHasher ;
27- use serde:: Deserialize ;
27+ use serde:: { de :: Deserializer , Deserialize } ;
2828use std:: collections:: HashMap ;
2929use std:: path:: Path ;
3030use text_embeddings_backend_core:: {
@@ -33,19 +33,58 @@ use text_embeddings_backend_core::{
3333
3434/// This enum is needed to be able to differentiate between jina models that also use
3535/// the `bert` model type and valid Bert models.
36- /// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
37- /// run but is still better than the other options...
38- #[ derive( Debug , Clone , PartialEq , Deserialize ) ]
39- #[ serde( tag = "_name_or_path" ) ]
36+ #[ derive( Debug , Clone , PartialEq ) ]
4037pub enum BertConfigWrapper {
41- #[ serde( rename = "jinaai/jina-bert-implementation" ) ]
    /// Selected when `_name_or_path` is exactly `jinaai/jina-bert-implementation`, or when
    /// `auto_map.AutoConfig` contains that repository id.
4238 JinaBert ( BertConfig ) ,
43- #[ serde( rename = "jinaai/jina-bert-v2-qk-post-norm" ) ]
    /// Selected when `_name_or_path` is exactly `jinaai/jina-bert-v2-qk-post-norm`, or when
    /// `auto_map.AutoConfig` contains that repository id (and the `JinaBert` rule did not match).
4439 JinaCodeBert ( BertConfig ) ,
45- #[ serde( untagged) ]
    /// Fallback for every configuration that matches neither of the JinaBERT rules.
4640 Bert ( BertConfig ) ,
4741}
4842
43+ /// Custom deserializer is required as we need to capture both whether the `_name_or_path` value
44+ /// is any of the JinaBERT alternatives, or alternatively to also support fine-tunes and re-uploads
45+ /// with Sentence Transformers, we also need to check the value for the `auto_map.AutoConfig`
46+ /// configuration file, and see if that points to the relevant remote code repositories on the Hub
47+ impl < ' de > Deserialize < ' de > for BertConfigWrapper {
48+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
49+ where
50+ D : Deserializer < ' de > ,
51+ {
52+ use serde:: de:: Error ;
53+
54+ #[ allow( unused_mut) ]
55+ let mut value = serde_json:: Value :: deserialize ( deserializer) ?;
56+
57+ let name_or_path = value
58+ . get ( "_name_or_path" )
59+ . and_then ( |v| v. as_str ( ) )
60+ . map ( ToString :: to_string)
61+ . unwrap_or_default ( ) ;
62+
63+ let auto_config = value
64+ . get ( "auto_map" )
65+ . and_then ( |v| v. get ( "AutoConfig" ) )
66+ . and_then ( |v| v. as_str ( ) )
67+ . map ( ToString :: to_string)
68+ . unwrap_or_default ( ) ;
69+
70+ let config = BertConfig :: deserialize ( value) . map_err ( Error :: custom) ?;
71+
72+ if name_or_path == "jinaai/jina-bert-implementation"
73+ || auto_config. contains ( "jinaai/jina-bert-implementation" )
74+ {
75+ // https://huggingface.co/jinaai/jina-bert-implementation
76+ Ok ( Self :: JinaBert ( config) )
77+ } else if name_or_path == "jinaai/jina-bert-v2-qk-post-norm"
78+ || auto_config. contains ( "jinaai/jina-bert-v2-qk-post-norm" )
79+ {
80+ // https://huggingface.co/jinaai/jina-bert-v2-qk-post-norm
81+ Ok ( Self :: JinaCodeBert ( config) )
82+ } else {
83+ Ok ( Self :: Bert ( config) )
84+ }
85+ }
86+ }
87+
4988#[ derive( Deserialize ) ]
5089#[ serde( tag = "model_type" , rename_all = "kebab-case" ) ]
5190enum Config {
0 commit comments