Skip to content

Commit c7694c9

Browse files
committed
client: Add 'process_bulk' and 'redact' parameters
Add client support for the 'process_bulk' task, which allows providing a list of documents to be annotated, and add missing configuration options to the 'redact' method.

Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 9d4da2a commit c7694c9

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

client/cogstack_model_gateway_client/client.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
from collections.abc import Iterable
34
from functools import wraps
45

56
import httpx
@@ -46,6 +47,7 @@ async def submit_task(
4647
model_name: str = None,
4748
task: str = None,
4849
data=None,
50+
json=None,
4951
files=None,
5052
params=None,
5153
headers=None,
@@ -57,7 +59,9 @@ async def submit_task(
5759
if not model_name:
5860
raise ValueError("Please provide a model name or set a default model for the client.")
5961
url = f"{self.base_url}/models/{model_name}/tasks/{task}"
60-
resp = await self._client.post(url, data=data, files=files, params=params, headers=headers)
62+
resp = await self._client.post(
63+
url, data=data, json=json, files=files, params=params, headers=headers
64+
)
6165
resp.raise_for_status()
6266
task_info = resp.json()
6367
if wait_for_completion:
@@ -84,18 +88,51 @@ async def process(
8488
return_result=return_result,
8589
)
8690

91+
async def process_bulk(
92+
self,
93+
texts: list[str],
94+
model_name: str = None,
95+
wait_for_completion: bool = True,
96+
return_result: bool = True,
97+
):
98+
"""Generate annotations for a list of texts."""
99+
return await self.submit_task(
100+
model_name=model_name,
101+
task="process_bulk",
102+
json=texts,
103+
headers={"Content-Type": "application/json"},
104+
wait_for_completion=wait_for_completion,
105+
return_result=return_result,
106+
)
107+
87108
async def redact(
88109
self,
89110
text: str,
111+
concepts_to_keep: Iterable[str] = None,
112+
warn_on_no_redaction: bool = None,
113+
mask: str = None,
114+
hash: bool = None,
90115
model_name: str = None,
91116
wait_for_completion: bool = True,
92117
return_result: bool = True,
93118
):
94119
"""Redact sensitive information from the provided text."""
120+
params = {
121+
k: v
122+
for k, v in {
123+
"concepts_to_keep": concepts_to_keep,
124+
"warn_on_no_redaction": warn_on_no_redaction,
125+
"mask": mask,
126+
"hash": hash,
127+
}.items()
128+
if v is not None
129+
} or None
130+
95131
return await self.submit_task(
96132
model_name=model_name,
97133
task="redact",
98134
data=text,
135+
params=params,
99136
headers={"Content-Type": "text/plain"},
100137
wait_for_completion=wait_for_completion,
101138
return_result=return_result,
@@ -225,6 +262,9 @@ def submit_task(self, *args, **kwargs):
225262
def process(self, *args, **kwargs):
226263
return self._loop.run_until_complete(self._client.process(*args, **kwargs))
227264

265+
def process_bulk(self, *args, **kwargs):
266+
return self._loop.run_until_complete(self._client.process_bulk(*args, **kwargs))
267+
228268
def redact(self, *args, **kwargs):
229269
return self._loop.run_until_complete(self._client.redact(*args, **kwargs))
230270

tests/unit/client/test_client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ async def test_submit_task_success(mock_httpx_async_client):
7676
mock_client_instance.post.assert_awaited_once_with(
7777
"http://test-gateway.com/models/my_model/tasks/process",
7878
data="some text",
79+
json=None,
7980
files=None,
8081
params=None,
8182
headers=None,
@@ -100,6 +101,7 @@ async def test_submit_task_with_default_model(mock_httpx_async_client):
100101
mock_client_instance.post.assert_awaited_once_with(
101102
"http://test-gateway.com/models/default_model/tasks/process",
102103
data="some text",
104+
json=None,
103105
files=None,
104106
params=None,
105107
headers=None,
@@ -176,16 +178,35 @@ async def test_process_method(mocker):
176178
)
177179

178180

181+
@pytest.mark.asyncio
182+
async def test_process_bulk_method(mocker):
183+
"""Test the process_bulk method."""
184+
async with GatewayClient(base_url="http://test-gateway.com") as client:
185+
mock_submit_task = mocker.patch.object(client, "submit_task", new=AsyncMock())
186+
await client.process_bulk(texts=["text1", "text2"], model_name="bulk_model")
187+
mock_submit_task.assert_awaited_once_with(
188+
model_name="bulk_model",
189+
task="process_bulk",
190+
json=["text1", "text2"],
191+
headers={"Content-Type": "application/json"},
192+
wait_for_completion=True,
193+
return_result=True,
194+
)
195+
196+
179197
@pytest.mark.asyncio
180198
async def test_redact_method(mocker):
181199
"""Test the redact method."""
182200
async with GatewayClient(base_url="http://test-gateway.com") as client:
183201
mock_submit_task = mocker.patch.object(client, "submit_task", new=AsyncMock())
184-
await client.redact(text="sensitive text", model_name="deid_model")
202+
await client.redact(
203+
text="sensitive text", concepts_to_keep=["label1", "label2"], model_name="deid_model"
204+
)
185205
mock_submit_task.assert_awaited_once_with(
186206
model_name="deid_model",
187207
task="redact",
188208
data="sensitive text",
209+
params={"concepts_to_keep": ["label1", "label2"]},
189210
headers={"Content-Type": "text/plain"},
190211
wait_for_completion=True,
191212
return_result=True,

0 commit comments

Comments
 (0)