@@ -52,7 +52,7 @@ def test_init(self, mock_sentence_transformer):
5252 embedder = SentenceTransformerEmbedder (self .config )
5353
5454 # Verify model initialization
55- mock_sentence_transformer .assert_called_once_with ("all-MiniLM-L6-v2" ,
55+ mock_sentence_transformer .assert_called_once_with (model_name_or_path = "all-MiniLM-L6-v2" ,
5656 device = "cpu" )
5757 self .assertEqual (embedder .model , self .mock_model )
5858 self .assertEqual (embedder ._cache_size , 100 )
@@ -74,7 +74,7 @@ def test_init_with_path(self, mock_sentence_transformer):
7474
7575 embedder = SentenceTransformerEmbedder (config )
7676
77- mock_sentence_transformer .assert_called_once_with ("all-MiniLM-L6-v2" ,
77+ mock_sentence_transformer .assert_called_once_with (model_name_or_path = "all-MiniLM-L6-v2" ,
7878 device = "cuda" )
7979 self .assertEqual (embedder ._cache_size , 50 )
8080 self .assertEqual (embedder ._memory_threshold , 0.9 )
@@ -102,9 +102,41 @@ def test_init_cache_disabled_logging(self, mock_log,
102102
103103 # Verify embedder was still initialized properly
104104 self .assertEqual (embedder ._cache_size , 0 )
105- mock_sentence_transformer .assert_called_once_with ("all-MiniLM-L6-v2" ,
105+ mock_sentence_transformer .assert_called_once_with (model_name_or_path = "all-MiniLM-L6-v2" ,
106106 device = "cpu" )
107107
108+ @patch ('ragl.embed.sentencetransformer.SentenceTransformer' )
109+ @patch ('ragl.embed.sentencetransformer._LOG' )
110+ def test_init_kwargs_device_warning (self , mock_log ,
111+ mock_sentence_transformer ):
112+ """Test that device in init_kwargs triggers warning and is ignored."""
113+ mock_sentence_transformer .return_value = self .mock_model
114+
115+ config = SentenceTransformerConfig (
116+ model_name_or_path = 'all-MiniLM-L6-v2' ,
117+ cache_maxsize = 100 ,
118+ memory_threshold = 0.8 ,
119+ auto_clear_cache = True ,
120+ device = "cpu" ,
121+ init_kwargs = {"device" : "cuda" , "trust_remote_code" : True }
122+ )
123+
124+ embedder = SentenceTransformerEmbedder (config )
125+
126+ # Verify warning was logged
127+ mock_log .warning .assert_called_once_with (
128+ 'Ignoring device setting in init_kwargs (%s); '
129+ 'use config.device (%s) instead' ,
130+ 'cuda' , 'cpu'
131+ )
132+
133+ # Verify SentenceTransformer was called with config.device, not init_kwargs device
134+ mock_sentence_transformer .assert_called_once_with (
135+ model_name_or_path = "all-MiniLM-L6-v2" ,
136+ device = "cpu" , # Should use config.device
137+ trust_remote_code = True # Other init_kwargs should still be passed
138+ )
139+
108140 @patch ('ragl.embed.sentencetransformer.SentenceTransformer' )
109141 def test_dimensions_property (self , mock_sentence_transformer ):
110142 """Test dimensions property."""
@@ -123,7 +155,7 @@ def test_dimensions_property_assertion(self, mock_sentence_transformer):
123155 self .mock_model .get_sentence_embedding_dimension .return_value = "invalid"
124156
125157 # Now the assertion should fail during initialization
126- with self .assertRaises (AssertionError ):
158+ with self .assertRaises (ValueError ):
127159 embedder = SentenceTransformerEmbedder (self .config )
128160
129161 @patch ('ragl.embed.sentencetransformer.SentenceTransformer' )
0 commit comments