Skip to content

Commit 8d6dcca

Browse files
SentenceTransformerEmbedder.__init__() kwargs pass-through (#27)
Added: SentenceTransformerConfig init_kwargs attribute/parameter. Changed: SentenceTransformerEmbedder now passes the config.init_kwargs attribute through to SentenceTransformer.__init__().
1 parent b8b29e6 commit 8d6dcca

File tree

4 files changed

+59
-11
lines changed

4 files changed

+59
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ exclude = ["tests", "work"]
3232

3333
[project]
3434
name = "ragl"
35-
version = "0.10.0"
35+
version = "0.10.1"
3636
dependencies = [
3737
"bleach",
3838
"numpy",

ragl/config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import logging
2525
import re
26-
from dataclasses import dataclass
26+
from dataclasses import dataclass, field
2727
from typing import Any
2828

2929
from ragl.exceptions import ConfigurationError
@@ -104,7 +104,8 @@ class SentenceTransformerConfig(EmbedderConfig):
104104
to a model on disk.
105105
106106
cache_maxsize:
107-
Maximum number of embeddings to cache in memory.
107+
Maximum number of entries to cache in memory. Set to
108+
0 to disable caching.
108109
device:
109110
Device to use for embedding.
110111
auto_clear_cache:
@@ -116,14 +117,18 @@ class SentenceTransformerConfig(EmbedderConfig):
116117
Threshold for memory usage before cleaning up cache.
117118
This is a float between 0.0 and 1.0, where 1.0 means
118119
100% memory usage.
120+
init_kwargs:
121+
Additional keyword arguments to pass to the
122+
SentenceTransformer constructor.
119123
"""
120124

121125
model_name_or_path: str = 'all-mpnet-base-v2'
122-
cache_maxsize: int = 10_000 # set this to 0 to disable caching
126+
cache_maxsize: int = 10_000
123127
device: str | None = None
124128
auto_clear_cache: bool = True
125129
show_progress: bool = False
126130
memory_threshold: float = 0.9
131+
init_kwargs: dict[str, Any] = field(default_factory=dict)
127132

128133
def __post_init__(self) -> None:
129134
"""Validate configuration after initialization."""

ragl/embed/sentencetransformer.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ class SentenceTransformerEmbedder:
9191
def dimensions(self) -> int:
9292
"""Retrieve the embedding dimension count."""
9393
dimensions = self.model.get_sentence_embedding_dimension()
94-
assert isinstance(dimensions, int)
94+
if not isinstance(dimensions, int) or dimensions <= 0:
95+
raise ValueError('Invalid embedding dimensions '
96+
'retrieved from model')
9597
return dimensions
9698

9799
def __init__(self, config: SentenceTransformerConfig) -> None:
@@ -111,11 +113,20 @@ def __init__(self, config: SentenceTransformerConfig) -> None:
111113
_LOG.info('Cache disabled (maxsize=%d)', config.cache_maxsize)
112114

113115
model_path = Path(config.model_name_or_path)
114-
self.model = SentenceTransformer(str(model_path), device=config.device)
116+
kwargs_device = config.init_kwargs.pop('device', None)
117+
if kwargs_device is not None:
118+
_LOG.warning('Ignoring device setting in init_kwargs (%s); '
119+
'use config.device (%s) instead',
120+
kwargs_device, config.device)
121+
self.model = SentenceTransformer(
122+
model_name_or_path=str(model_path),
123+
device=config.device,
124+
**config.init_kwargs,
125+
)
115126
self._cache_size = config.cache_maxsize
116-
self._memory_threshold = config.memory_threshold
117127
self._auto_cleanup = config.auto_clear_cache
118128
self._show_progress = config.show_progress
129+
self._memory_threshold = config.memory_threshold
119130
self._embed_cached = lru_cache(self._cache_size)(self._embed_impl)
120131
_LOG.debug('Embedder initialized: dims=%d, cache_size=%d, device=%s',
121132
self.dimensions, self._cache_size, config.device)

tests/functional/ragl/embed/test_sentencetransformer.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_init(self, mock_sentence_transformer):
5252
embedder = SentenceTransformerEmbedder(self.config)
5353

5454
# Verify model initialization
55-
mock_sentence_transformer.assert_called_once_with("all-MiniLM-L6-v2",
55+
mock_sentence_transformer.assert_called_once_with(model_name_or_path="all-MiniLM-L6-v2",
5656
device="cpu")
5757
self.assertEqual(embedder.model, self.mock_model)
5858
self.assertEqual(embedder._cache_size, 100)
@@ -74,7 +74,7 @@ def test_init_with_path(self, mock_sentence_transformer):
7474

7575
embedder = SentenceTransformerEmbedder(config)
7676

77-
mock_sentence_transformer.assert_called_once_with("all-MiniLM-L6-v2",
77+
mock_sentence_transformer.assert_called_once_with(model_name_or_path="all-MiniLM-L6-v2",
7878
device="cuda")
7979
self.assertEqual(embedder._cache_size, 50)
8080
self.assertEqual(embedder._memory_threshold, 0.9)
@@ -102,9 +102,41 @@ def test_init_cache_disabled_logging(self, mock_log,
102102

103103
# Verify embedder was still initialized properly
104104
self.assertEqual(embedder._cache_size, 0)
105-
mock_sentence_transformer.assert_called_once_with("all-MiniLM-L6-v2",
105+
mock_sentence_transformer.assert_called_once_with(model_name_or_path="all-MiniLM-L6-v2",
106106
device="cpu")
107107

108+
@patch('ragl.embed.sentencetransformer.SentenceTransformer')
109+
@patch('ragl.embed.sentencetransformer._LOG')
110+
def test_init_kwargs_device_warning(self, mock_log,
111+
mock_sentence_transformer):
112+
"""Test that device in init_kwargs triggers warning and is ignored."""
113+
mock_sentence_transformer.return_value = self.mock_model
114+
115+
config = SentenceTransformerConfig(
116+
model_name_or_path='all-MiniLM-L6-v2',
117+
cache_maxsize=100,
118+
memory_threshold=0.8,
119+
auto_clear_cache=True,
120+
device="cpu",
121+
init_kwargs={"device": "cuda", "trust_remote_code": True}
122+
)
123+
124+
embedder = SentenceTransformerEmbedder(config)
125+
126+
# Verify warning was logged
127+
mock_log.warning.assert_called_once_with(
128+
'Ignoring device setting in init_kwargs (%s); '
129+
'use config.device (%s) instead',
130+
'cuda', 'cpu'
131+
)
132+
133+
# Verify SentenceTransformer was called with config.device, not init_kwargs device
134+
mock_sentence_transformer.assert_called_once_with(
135+
model_name_or_path="all-MiniLM-L6-v2",
136+
device="cpu", # Should use config.device
137+
trust_remote_code=True # Other init_kwargs should still be passed
138+
)
139+
108140
@patch('ragl.embed.sentencetransformer.SentenceTransformer')
109141
def test_dimensions_property(self, mock_sentence_transformer):
110142
"""Test dimensions property."""
@@ -123,7 +155,7 @@ def test_dimensions_property_assertion(self, mock_sentence_transformer):
123155
self.mock_model.get_sentence_embedding_dimension.return_value = "invalid"
124156

125157
# Now the assertion should fail during initialization
126-
with self.assertRaises(AssertionError):
158+
with self.assertRaises(ValueError):
127159
embedder = SentenceTransformerEmbedder(self.config)
128160

129161
@patch('ragl.embed.sentencetransformer.SentenceTransformer')

0 commit comments

Comments
 (0)