|
7 | 7 | from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union |
8 | 8 |
|
9 | 9 | from lorax.types import ( |
| 10 | + BatchRequest, |
10 | 11 | StreamResponse, |
11 | 12 | Response, |
12 | 13 | Request, |
@@ -69,6 +70,7 @@ def __init__( |
69 | 70 | self.base_url = base_url |
70 | 71 | self.embed_endpoint = f"{base_url}/embed" |
71 | 72 | self.classify_endpoint = f"{base_url}/classify" |
| 73 | + self.classify_batch_endpoint = f"{base_url}/classify_batch" |
72 | 74 | self.headers = headers |
73 | 75 | self.cookies = cookies |
74 | 76 | self.timeout = timeout |
@@ -470,8 +472,34 @@ def classify(self, inputs: str) -> ClassifyResponse: |
470 | 472 | if resp.status_code != 200: |
471 | 473 | raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) |
472 | 474 |
|
473 | | - print(payload) |
474 | | - return ClassifyResponse(**payload) |
| 475 | + return ClassifyResponse(entities=payload) |
| 476 | + |
| 477 | + def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]: |
| 478 | + """ |
| 479 | + Given a list of inputs, run token classification on the text using the model |
| 480 | +
|
| 481 | + Args: |
| 482 | + inputs (`List[str]`): |
| 483 | + List of input texts |
| 484 | + |
| 485 | + Returns: |
| 486 | + List[Entities]: Entities found in the input text |
| 487 | + """ |
| 488 | + request = BatchRequest(inputs=inputs) |
| 489 | + |
| 490 | + resp = requests.post( |
| 491 | + self.classify_batch_endpoint, |
| 492 | + json=request.dict(by_alias=True), |
| 493 | + headers=self.headers, |
| 494 | + cookies=self.cookies, |
| 495 | + timeout=self.timeout, |
| 496 | + ) |
| 497 | + |
| 498 | + payload = resp.json() |
| 499 | + if resp.status_code != 200: |
| 500 | + raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) |
| 501 | + |
| 502 | + return [ClassifyResponse(entities=e) for e in payload] |
475 | 503 |
|
476 | 504 |
|
477 | 505 | class AsyncClient: |
|
0 commit comments