1414
1515logger = logging .getLogger (__name__ )
1616
17+ # Valid pooling methods
18+ VALID_POOLING_METHODS = {"mean" , "last_token" , "eos_token" }
19+
1720
1821class Encoder :
1922 def __init__ (
@@ -38,11 +41,16 @@ def __init__(
3841 """
3942 self .backend_name = backend
4043 self .backend_instance : Backend
44+
45+ logger .debug (f"Initializing Encoder with model='{ model_name } ', backend='{ backend } '" )
46+ logger .debug (f"Device: { device } , Quantization: { quantization } " )
4147
4248 if backend == "transformers" :
49+ logger .debug ("Loading Transformers backend..." )
4350 self .backend_instance = TransformersBackend (
4451 model_name , device = device , quantization = quantization , ** kwargs
4552 )
53+ logger .debug ("Transformers backend loaded successfully" )
4654 elif backend == "vllm" :
4755 if VLLMBackend is None :
4856 raise ImportError (
@@ -53,10 +61,12 @@ def __init__(
5361 # vLLM backend requires a strict string for device (e.g. "cuda").
5462 # If 'device' is None (auto), default to "cuda".
5563 vllm_device = device if device is not None else "cuda"
56-
64+
65+ logger .debug (f"Loading vLLM backend with device='{ vllm_device } '..." )
5766 self .backend_instance = VLLMBackend (
5867 model_name , device = vllm_device , quantization = quantization , ** kwargs
5968 )
69+ logger .debug ("vLLM backend loaded successfully" )
6070 else :
6171 raise ValueError (
6272 f"Unknown backend: { backend } . Supported backends are 'transformers' and 'vllm'."
@@ -70,7 +80,7 @@ def encode(
7080 prompt_template : Optional [str ] = None ,
7181 batch_size : Optional [int ] = None ,
7282 ** kwargs : Any ,
73- ) -> Any :
83+ ) -> torch . Tensor :
7484 """
7585 Encode text into embeddings.
7686
@@ -85,10 +95,14 @@ def encode(
8595 prompt_template: Optional prompt template ('prompteol', 'pcoteol', 'ke').
8696 When specified, wraps the input text with the template.
8797 batch_size: Batch size for processing. If None, processes all inputs at once.
98+ Must be > 0 if provided.
8899 **kwargs: Backend specific arguments.
89100
90101 Returns:
91- Embeddings as numpy array or torch tensor.
102+ Embeddings as torch tensor.
103+
104+ Raises:
105+ ValueError: If pooling_method is invalid or batch_size <= 0.
92106 """
93107 # Smart default: use last_token pooling when template is provided
94108 if pooling_method is None :
@@ -97,15 +111,35 @@ def encode(
97111 else :
98112 pooling_method = "mean"
99113
114+ # Validate pooling_method
115+ if pooling_method not in VALID_POOLING_METHODS :
116+ raise ValueError (
117+ f"Invalid pooling_method: '{ pooling_method } '. "
118+ f"Valid options are: { ', ' .join (sorted (VALID_POOLING_METHODS ))} "
119+ )
120+
121+ # Validate batch_size
122+ if batch_size is not None and batch_size <= 0 :
123+ raise ValueError (
124+ f"batch_size must be a positive integer, got: { batch_size } "
125+ )
126+
100127 if isinstance (text , str ):
101128 text = [text ]
129+
130+ logger .debug (
131+ f"Encoding { len (text )} text(s) with pooling_method='{ pooling_method } ', "
132+ f"layer_index={ layer_index } , prompt_template={ prompt_template } "
133+ )
102134
103135 if batch_size is None :
136+ logger .debug ("Processing all inputs in a single batch" )
104137 return self .backend_instance .encode (
105138 text , pooling_method = pooling_method , layer_index = layer_index ,
106139 prompt_template = prompt_template , ** kwargs
107140 )
108141
142+ logger .debug (f"Processing in batches of size { batch_size } " )
109143 results = []
110144 total = len (text )
111145
0 commit comments