Skip to content

Commit 381b5d2

Browse files
authored
fix: Implement context management for clients (#172)
We missed closing httpx clients if they are CompassClient-owned (or other client types). This PR addresses this issue.
1 parent 51a6ab8 commit 381b5d2

File tree

5 files changed

+274
-26
lines changed

5 files changed

+274
-26
lines changed

cohere_compass/clients/compass.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dataclasses import dataclass
1919
from datetime import timedelta
2020
from statistics import mean
21+
from types import TracebackType
2122
from typing import Any, Literal
2223

2324
# 3rd party imports
@@ -267,9 +268,9 @@ def __init__(
267268
if httpx_client.timeout.read
268269
else DEFAULT_COMPASS_CLIENT_TIMEOUT
269270
)
270-
self.httpx_client = httpx_client or httpx.Client(
271-
timeout=self.timeout.total_seconds()
272-
)
271+
self.httpx = httpx_client or httpx.Client(timeout=self.timeout.total_seconds())
272+
self._own_httpx_client = httpx_client is None
273+
self._closed = False
273274

274275
self.bearer_token = bearer_token
275276

@@ -281,8 +282,23 @@ def __init__(
281282
self.retry_wait = retry_wait
282283

283284
def close(self):
284-
"""Close the HTTP client connection."""
285-
self.httpx_client.close()
285+
"""Close the httpx client if it was created by this instance."""
286+
if self._own_httpx_client and not self._closed:
287+
self.httpx.close()
288+
self._closed = True
289+
290+
def __enter__(self):
291+
"""For use by "with" statements."""
292+
return self
293+
294+
def __exit__(
295+
self,
296+
exc_type: type[BaseException] | None,
297+
exc_value: BaseException | None,
298+
traceback: TracebackType | None,
299+
) -> None:
300+
"""For use by "with" statements."""
301+
self.close()
286302

287303
def get_models(
288304
self,
@@ -1392,27 +1408,27 @@ def _send_http_request(
13921408
headers = {"Authorization": f"Bearer {self.bearer_token}"}
13931409

13941410
if http_method == "GET":
1395-
response = self.httpx_client.get(
1411+
response = self.httpx.get(
13961412
target_path,
13971413
headers=headers,
13981414
timeout=timeout.total_seconds(),
13991415
)
14001416
elif http_method == "POST":
1401-
response = self.httpx_client.post(
1417+
response = self.httpx.post(
14021418
target_path,
14031419
json=data_dict,
14041420
headers=headers,
14051421
timeout=timeout.total_seconds(),
14061422
)
14071423
elif http_method == "PUT":
1408-
response = self.httpx_client.put(
1424+
response = self.httpx.put(
14091425
target_path,
14101426
json=data_dict,
14111427
headers=headers,
14121428
timeout=timeout.total_seconds(),
14131429
)
14141430
elif http_method == "DELETE":
1415-
response = self.httpx_client.delete(
1431+
response = self.httpx.delete(
14161432
target_path,
14171433
headers=headers,
14181434
timeout=timeout.total_seconds(),

cohere_compass/clients/compass_async.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from collections.abc import AsyncIterable, Iterable
1515
from datetime import timedelta
1616
from statistics import mean
17+
from types import TracebackType
1718
from typing import Any, Literal
1819

1920
# 3rd party imports
@@ -133,9 +134,11 @@ def __init__(
133134
if httpx_client.timeout.read
134135
else DEFAULT_COMPASS_CLIENT_TIMEOUT
135136
)
136-
self.httpx_client = httpx_client or httpx.AsyncClient(
137+
self.httpx = httpx_client or httpx.AsyncClient(
137138
timeout=self.timeout.total_seconds()
138139
)
140+
self._own_httpx_client = httpx_client is None
141+
self._closed = False
139142

140143
self.bearer_token = bearer_token
141144

@@ -149,8 +152,25 @@ def __init__(
149152
self.retry_wait = retry_wait
150153

151154
async def aclose(self):
152-
"""Close the HTTP client."""
153-
await self.httpx_client.aclose()
155+
"""Close the httpx client if it was created by the CompassAsyncClient."""
156+
if self._own_httpx_client and not self._closed:
157+
await self.httpx.aclose()
158+
self._closed = True
159+
160+
close = aclose # Alias for consistency with sync client
161+
162+
async def __aenter__(self):
163+
"""For use by "async with" statements."""
164+
return self
165+
166+
async def __aexit__(
167+
self,
168+
exc_type: type[BaseException] | None,
169+
exc_value: BaseException | None,
170+
traceback: TracebackType | None,
171+
) -> None:
172+
"""For use by "async with" statements."""
173+
await self.aclose()
154174

155175
async def get_models(
156176
self,
@@ -1265,27 +1285,27 @@ async def _send_http_request(
12651285
headers = {"Authorization": f"Bearer {self.bearer_token}"}
12661286

12671287
if http_method == "GET":
1268-
response = await self.httpx_client.get(
1288+
response = await self.httpx.get(
12691289
target_path,
12701290
headers=headers,
12711291
timeout=timeout.total_seconds(),
12721292
)
12731293
elif http_method == "POST":
1274-
response = await self.httpx_client.post(
1294+
response = await self.httpx.post(
12751295
target_path,
12761296
json=data_dict,
12771297
headers=headers,
12781298
timeout=timeout.total_seconds(),
12791299
)
12801300
elif http_method == "PUT":
1281-
response = await self.httpx_client.put(
1301+
response = await self.httpx.put(
12821302
target_path,
12831303
json=data_dict,
12841304
headers=headers,
12851305
timeout=timeout.total_seconds(),
12861306
)
12871307
elif http_method == "DELETE":
1288-
response = await self.httpx_client.delete(
1308+
response = await self.httpx.delete(
12891309
target_path,
12901310
headers=headers,
12911311
timeout=timeout.total_seconds(),

cohere_compass/clients/parser.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections.abc import Callable, Iterable
1313
from concurrent.futures import ThreadPoolExecutor
1414
from datetime import timedelta
15+
from types import TracebackType
1516
from typing import Any
1617

1718
# 3rd party imports
@@ -42,9 +43,7 @@
4243
)
4344
from cohere_compass.utils.fs import open_document, scan_folder
4445
from cohere_compass.utils.iter import imap_parallel
45-
from cohere_compass.utils.retry import (
46-
is_retryable_compass_exception,
47-
)
46+
from cohere_compass.utils.retry import is_retryable_compass_exception
4847

4948
Fn_or_Dict = dict[str, Any] | Callable[[CompassDocument], dict[str, Any]]
5049

@@ -123,15 +122,34 @@ def __init__(
123122
if httpx_client.timeout.read
124123
else DEFAULT_COMPASS_PARSER_CLIENT_TIMEOUT
125124
)
126-
self.httpx_client = httpx_client or httpx.Client(
127-
timeout=self.timeout.total_seconds()
128-
)
125+
self.httpx = httpx_client or httpx.Client(timeout=self.timeout.total_seconds())
126+
self._own_httpx_client = httpx_client is None
127+
self._closed = False
129128

130129
self.metadata_config = metadata_config
131130
logger.info(
132131
f"CompassParserClient initialized with parser_url: {self.parser_url}"
133132
)
134133

134+
def close(self):
135+
"""Close the httpx client if it was created by this instance."""
136+
if self._own_httpx_client and not self._closed:
137+
self.httpx.close()
138+
self._closed = True
139+
140+
def __enter__(self):
141+
"""For use by "with" statements."""
142+
return self
143+
144+
def __exit__(
145+
self,
146+
exc_type: type[BaseException] | None,
147+
exc_value: BaseException | None,
148+
traceback: TracebackType | None,
149+
) -> None:
150+
"""For use by "with" statements."""
151+
self.close()
152+
135153
def process_folder(
136154
self,
137155
*,
@@ -390,7 +408,7 @@ def _process_file_bytes(
390408
headers = {"Authorization": f"Bearer {self.bearer_token}"}
391409

392410
with handle_httpx_exceptions():
393-
res = self.httpx_client.post(
411+
res = self.httpx.post(
394412
url=f"{self.parser_url}/v1/process_file",
395413
data={"data": json.dumps(params.model_dump())},
396414
files={"file": (filename, file_bytes)},

cohere_compass/clients/parser_async.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections.abc import Callable
1212
from concurrent.futures import ThreadPoolExecutor
1313
from datetime import timedelta
14+
from types import TracebackType
1415
from typing import Any
1516

1617
# 3rd party imports
@@ -39,9 +40,7 @@
3940
)
4041
from cohere_compass.utils.asyn import async_map
4142
from cohere_compass.utils.fs import open_document, scan_folder
42-
from cohere_compass.utils.retry import (
43-
is_retryable_compass_exception,
44-
)
43+
from cohere_compass.utils.retry import is_retryable_compass_exception
4544

4645
Fn_or_Dict = dict[str, Any] | Callable[[CompassDocument], dict[str, Any]]
4746

@@ -123,12 +122,35 @@ def __init__(
123122
self.httpx = httpx_client or httpx.AsyncClient(
124123
timeout=self.timeout.total_seconds()
125124
)
125+
self._own_httpx_client = httpx_client is None
126+
self._closed = False
126127

127128
self.metadata_config = metadata_config
128129
logger.info(
129130
f"CompassParserClient initialized with parser_url: {self.parser_url}"
130131
)
131132

133+
async def aclose(self):
134+
"""Close the httpx client if it was created by the CompassParserAsyncClient."""
135+
if self._own_httpx_client and not self._closed:
136+
await self.httpx.aclose()
137+
self._closed = True
138+
139+
close = aclose # Alias for consistency with sync client
140+
141+
async def __aenter__(self):
142+
"""For use by "async with" statements."""
143+
return self
144+
145+
async def __aexit__(
146+
self,
147+
exc_type: type[BaseException] | None,
148+
exc_value: BaseException | None,
149+
traceback: TracebackType | None,
150+
) -> None:
151+
"""For use by "async with" statements."""
152+
await self.aclose()
153+
132154
def process_folder(
133155
self,
134156
*,

0 commit comments

Comments
 (0)