Skip to content

Commit a6b60e9

Browse files
authored
Fix classify and classify_batch for Python client (#608)
1 parent a446ae4 commit a6b60e9

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

clients/python/lorax/client.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union
88

99
from lorax.types import (
10+
BatchRequest,
1011
StreamResponse,
1112
Response,
1213
Request,
@@ -69,6 +70,7 @@ def __init__(
6970
self.base_url = base_url
7071
self.embed_endpoint = f"{base_url}/embed"
7172
self.classify_endpoint = f"{base_url}/classify"
73+
self.classify_batch_endpoint = f"{base_url}/classify_batch"
7274
self.headers = headers
7375
self.cookies = cookies
7476
self.timeout = timeout
@@ -470,8 +472,34 @@ def classify(self, inputs: str) -> ClassifyResponse:
470472
if resp.status_code != 200:
471473
raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None)
472474

473-
print(payload)
474-
return ClassifyResponse(**payload)
475+
return ClassifyResponse(entities=payload)
476+
477+
def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]:
478+
"""
479+
Given a list of inputs, run token classification on the text using the model
480+
481+
Args:
482+
inputs (`List[str]`):
483+
List of input texts
484+
485+
Returns:
486+
List[Entities]: Entities found in the input text
487+
"""
488+
request = BatchRequest(inputs=inputs)
489+
490+
resp = requests.post(
491+
self.classify_batch_endpoint,
492+
json=request.dict(by_alias=True),
493+
headers=self.headers,
494+
cookies=self.cookies,
495+
timeout=self.timeout,
496+
)
497+
498+
payload = resp.json()
499+
if resp.status_code != 200:
500+
raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None)
501+
502+
return [ClassifyResponse(entities=e) for e in payload]
475503

476504

477505
class AsyncClient:

clients/python/lorax/types.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,28 @@ def valid_best_of_stream(cls, field_value, values):
220220
return field_value
221221

222222

223+
class BatchRequest(BaseModel):
224+
# Prompt
225+
inputs: List[str]
226+
# Generation parameters
227+
parameters: Optional[Parameters] = None
228+
# Whether to stream output tokens
229+
stream: bool = False
230+
231+
@field_validator("inputs")
232+
def valid_input(cls, v):
233+
if not v:
234+
raise ValidationError("`inputs` cannot be empty")
235+
return v
236+
237+
@field_validator("stream")
238+
def valid_best_of_stream(cls, field_value, values):
239+
parameters = values.data["parameters"]
240+
if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
241+
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
242+
return field_value
243+
244+
223245
# Decoder input tokens
224246
class InputToken(BaseModel):
225247
# Token ID from the model tokenizer

0 commit comments

Comments
 (0)