Skip to content

Commit 0813a29

Browse files
committed
Merge branch 'main' of github.com:j341nono/llemb
2 parents 3c015c1 + a963ebb commit 0813a29

File tree

4 files changed: +336 additions, −8 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "llemb"
3-
version = "0.2.2"
3+
version = "0.3.0"
44
description = "Embedding extractor for decoder-only LLMs"
55
readme = "README.md"
66
requires-python = ">=3.9"

src/llemb/backends/vllm_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ def encode(
9292
else:
9393
pooling_method = "mean"
9494

95-
# vLLM backend warning for layer_index
95+
# vLLM backend only supports last layer (-1)
9696
if layer_index is not None and layer_index != -1:
97-
logger.warning(
98-
f"layer_index={layer_index} was requested, but vLLM backend currently "
99-
"only supports the last layer output. This parameter is ignored."
97+
raise ValueError(
98+
f"layer_index={layer_index} is not supported by vLLM backend. "
99+
"vLLM currently only supports the last layer (layer_index=-1)."
100100
)
101101

102102
if isinstance(text, str):

src/llemb/core.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
logger = logging.getLogger(__name__)
1616

17+
# Valid pooling methods
18+
VALID_POOLING_METHODS = {"mean", "last_token", "eos_token"}
19+
1720

1821
class Encoder:
1922
def __init__(
@@ -38,11 +41,16 @@ def __init__(
3841
"""
3942
self.backend_name = backend
4043
self.backend_instance: Backend
44+
45+
logger.debug(f"Initializing Encoder with model='{model_name}', backend='{backend}'")
46+
logger.debug(f"Device: {device}, Quantization: {quantization}")
4147

4248
if backend == "transformers":
49+
logger.debug("Loading Transformers backend...")
4350
self.backend_instance = TransformersBackend(
4451
model_name, device=device, quantization=quantization, **kwargs
4552
)
53+
logger.debug("Transformers backend loaded successfully")
4654
elif backend == "vllm":
4755
if VLLMBackend is None:
4856
raise ImportError(
@@ -53,10 +61,12 @@ def __init__(
5361
# vLLM backend requires a strict string for device (e.g. "cuda").
5462
# If 'device' is None (auto), default to "cuda".
5563
vllm_device = device if device is not None else "cuda"
56-
64+
65+
logger.debug(f"Loading vLLM backend with device='{vllm_device}'...")
5766
self.backend_instance = VLLMBackend(
5867
model_name, device=vllm_device, quantization=quantization, **kwargs
5968
)
69+
logger.debug("vLLM backend loaded successfully")
6070
else:
6171
raise ValueError(
6272
f"Unknown backend: {backend}. Supported backends are 'transformers' and 'vllm'."
@@ -70,7 +80,7 @@ def encode(
7080
prompt_template: Optional[str] = None,
7181
batch_size: Optional[int] = None,
7282
**kwargs: Any,
73-
) -> Any:
83+
) -> torch.Tensor:
7484
"""
7585
Encode text into embeddings.
7686
@@ -85,10 +95,14 @@ def encode(
8595
prompt_template: Optional prompt template ('prompteol', 'pcoteol', 'ke').
8696
When specified, wraps the input text with the template.
8797
batch_size: Batch size for processing. If None, processes all inputs at once.
98+
Must be > 0 if provided.
8899
**kwargs: Backend specific arguments.
89100
90101
Returns:
91-
Embeddings as numpy array or torch tensor.
102+
Embeddings as torch tensor.
103+
104+
Raises:
105+
ValueError: If pooling_method is invalid or batch_size <= 0.
92106
"""
93107
# Smart default: use last_token pooling when template is provided
94108
if pooling_method is None:
@@ -97,15 +111,35 @@ def encode(
97111
else:
98112
pooling_method = "mean"
99113

114+
# Validate pooling_method
115+
if pooling_method not in VALID_POOLING_METHODS:
116+
raise ValueError(
117+
f"Invalid pooling_method: '{pooling_method}'. "
118+
f"Valid options are: {', '.join(sorted(VALID_POOLING_METHODS))}"
119+
)
120+
121+
# Validate batch_size
122+
if batch_size is not None and batch_size <= 0:
123+
raise ValueError(
124+
f"batch_size must be a positive integer, got: {batch_size}"
125+
)
126+
100127
if isinstance(text, str):
101128
text = [text]
129+
130+
logger.debug(
131+
f"Encoding {len(text)} text(s) with pooling_method='{pooling_method}', "
132+
f"layer_index={layer_index}, prompt_template={prompt_template}"
133+
)
102134

103135
if batch_size is None:
136+
logger.debug("Processing all inputs in a single batch")
104137
return self.backend_instance.encode(
105138
text, pooling_method=pooling_method, layer_index=layer_index,
106139
prompt_template=prompt_template, **kwargs
107140
)
108141

142+
logger.debug(f"Processing in batches of size {batch_size}")
109143
results = []
110144
total = len(text)
111145

0 commit comments

Comments (0)