Skip to content

Commit ac40282

Browse files
committed
[changes] Fixed AtlanClient event stream handling
1 parent 77e28f5 commit ac40282

File tree

5 files changed

+148
-39
lines changed

5 files changed

+148
-39
lines changed

pyatlan/client/asset.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,20 @@
2525
)
2626
from warnings import warn
2727

28-
import httpx
2928
from pydantic.v1 import (
3029
StrictStr,
3130
ValidationError,
3231
constr,
3332
parse_obj_as,
3433
validate_arguments,
3534
)
36-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
35+
from tenacity import (
36+
RetryError,
37+
retry,
38+
retry_if_exception_type,
39+
stop_after_attempt,
40+
wait_fixed,
41+
)
3742

3843
from pyatlan.client.common import ApiCaller
3944
from pyatlan.client.constants import (
@@ -844,7 +849,10 @@ def delete_by_guid(self, guid: Union[str, List[str]]) -> AssetMutationResponse:
844849
)
845850
response = AssetMutationResponse(**raw_json)
846851
for asset in response.assets_deleted(asset_type=Asset):
847-
self._wait_till_deleted(asset)
852+
try:
853+
self._wait_till_deleted(asset)
854+
except RetryError as err:
855+
raise ErrorCode.RETRY_OVERRUN.exception_with_parameters() from err
848856
return response
849857

850858
@retry(
@@ -854,12 +862,9 @@ def delete_by_guid(self, guid: Union[str, List[str]]) -> AssetMutationResponse:
854862
wait=wait_fixed(1),
855863
)
856864
def _wait_till_deleted(self, asset: Asset):
857-
try:
858-
asset = self.retrieve_minimal(guid=asset.guid, asset_type=Asset)
859-
if asset.status == EntityStatus.DELETED:
860-
return
861-
except httpx.TransportError as err:
862-
raise ErrorCode.RETRY_OVERRUN.exception_with_parameters() from err
865+
asset = self.retrieve_minimal(guid=asset.guid, asset_type=Asset)
866+
if asset.status == EntityStatus.DELETED:
867+
return
863868

864869
@validate_arguments
865870
def restore(self, asset_type: Type[A], qualified_name: str) -> bool:

pyatlan/client/atlan.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import copy
88
import json
99
import logging
10-
import os
11-
import shutil
1210
import uuid
1311
from contextvars import ContextVar
1412
from http import HTTPStatus
@@ -129,7 +127,7 @@ class AtlanClient(BaseSettings):
129127
read_timeout: float = 900.0 # 15 mins
130128
retry: Retry = DEFAULT_RETRY
131129
_401_has_retried: ContextVar[bool] = ContextVar("_401_has_retried", default=False)
132-
_session: httpx.Client = PrivateAttr(default_factory=lambda: httpx.Client())
130+
_session: httpx.Client = PrivateAttr()
133131
_request_params: dict = PrivateAttr()
134132
_user_id: Optional[str] = PrivateAttr(default=None)
135133
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
@@ -169,7 +167,7 @@ def __init__(self, **data):
169167
"authorization": f"Bearer {self.api_key}",
170168
}
171169
}
172-
# Configure httpx client with retry transport
170+
# Configure httpx client with the provided retry settings
173171
self._session = httpx.Client(
174172
transport=RetryTransport(retry=self.retry),
175173
headers={
@@ -409,8 +407,9 @@ def update_headers(self, header: Dict[str, str]):
409407

410408
def _handle_file_download(self, raw_response: Any, file_path: str) -> str:
411409
try:
412-
download_file = open(file_path, "wb")
413-
shutil.copyfileobj(raw_response, download_file)
410+
with open(file_path, "wb") as download_file:
411+
for chunk in raw_response:
412+
download_file.write(chunk)
414413
except Exception as err:
415414
raise ErrorCode.UNABLE_TO_DOWNLOAD_FILE.exception_with_parameters(
416415
str((hasattr(err, "strerror") and err.strerror) or err), file_path
@@ -441,15 +440,49 @@ def _call_api_internal(
441440
timeout=timeout,
442441
)
443442
elif api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
444-
response = self._session.request(
443+
with self._session.stream(
445444
api.method.value,
446445
path,
447446
**params,
448-
stream=True,
449447
timeout=timeout,
450-
)
451-
if download_file_path:
452-
return self._handle_file_download(response.raw, download_file_path)
448+
) as stream_response:
449+
if download_file_path:
450+
return self._handle_file_download(
451+
stream_response.iter_raw(), download_file_path
452+
)
453+
454+
# For event streams, we need to read the content while the stream is open
455+
# Store the response data and create a mock response object for common processing
456+
content = stream_response.read()
457+
text = content.decode("utf-8") if content else ""
458+
lines = []
459+
460+
# Only process lines for successful responses to avoid errors on error responses
461+
if stream_response.status_code == api.expected_status:
462+
# Split the already-decoded response text into lines for event processing
463+
lines = text.splitlines() if text else []
464+
465+
response_data = {
466+
"status_code": stream_response.status_code,
467+
"headers": stream_response.headers,
468+
"text": text,
469+
"content": content,
470+
"lines": lines,
471+
}
472+
473+
# Create a simple namespace object to mimic the response interface
474+
response = SimpleNamespace(
475+
status_code=response_data["status_code"],
476+
headers=response_data["headers"],
477+
text=response_data["text"],
478+
content=response_data["content"],
479+
_stream_lines=response_data[
480+
"lines"
481+
], # Store lines for event processing
482+
json=lambda: json.loads(response_data["text"])
483+
if response_data["text"]
484+
else {},
485+
)
453486
else:
454487
response = self._session.request(
455488
api.method.value,
@@ -496,14 +529,16 @@ def _call_api_internal(
496529
response,
497530
)
498531
if api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
499-
for line in response.iter_lines(decode_unicode=True):
500-
if not line:
501-
continue
502-
if not line.startswith("data: "):
503-
raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters(
504-
line
505-
)
506-
events.append(json.loads(line.split("data: ")[1]))
532+
# Process event stream using stored lines from the streaming response
533+
if hasattr(response, "_stream_lines"):
534+
for line in response._stream_lines:
535+
if not line:
536+
continue
537+
if not line.startswith("data: "):
538+
raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters(
539+
line
540+
)
541+
events.append(json.loads(line.split("data: ")[1]))
507542
if text_response:
508543
response_ = response.text
509544
else:

tests/integration/test_index_search.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import httpx
1111
import pytest
1212
from httpx_retries import Retry
13+
from pydantic.v1 import HttpUrl
1314

1415
from pyatlan.cache.source_tag_cache import SourceTagName
1516
from pyatlan.client.asset import LOGGER, IndexSearchResults, Persona, Purpose
@@ -854,12 +855,15 @@ def test_read_timeout(client: AtlanClient):
854855

855856

856857
def test_connect_timeout(client: AtlanClient):
857-
request = (FluentSearch().select()).to_request()
858+
request = FluentSearch().select().to_request()
859+
860+
# Use a non-routable IP so that the connection attempt will reliably time out
861+
# 192.0.2.1 is reserved for documentation/testing
858862
with client_connection(
859-
client=client, connect_timeout=0.0001, retry=Retry(total=0)
863+
client=client,
864+
base_url=HttpUrl("http://192.0.2.1:80", scheme="http"), # Non-routable test IP
865+
connect_timeout=0.001,
866+
retry=Retry(total=1),
860867
) as timed_client:
861-
with pytest.raises(
862-
httpx.ConnectTimeout,
863-
match="timed out",
864-
):
868+
with pytest.raises(httpx.ConnectTimeout):
865869
timed_client.asset.search(criteria=request)

tests/unit/test_file_client.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,22 @@ def mock_session():
7979
mock_response = Mock()
8080
mock_response.status_code = 200
8181
mock_response.raw = open(UPLOAD_FILE_PATH, "rb")
82-
mock_session.request.return_value = mock_response
82+
mock_response.headers = {}
83+
84+
# Mock the methods our streaming code expects
85+
mock_response.read.return_value = b"test content"
86+
87+
def mock_iter_raw(chunk_size=None):
88+
# Use the actual expected content from upload.txt
89+
content = b"test data 12345.\n"
90+
yield content
91+
92+
mock_response.iter_raw = mock_iter_raw
93+
94+
# Use Mock's context manager support
95+
mock_session.stream.return_value.__enter__.return_value = mock_response
96+
mock_session.stream.return_value.__exit__.return_value = None
97+
8398
yield mock_session
8499
assert os.path.exists(DOWNLOAD_FILE_PATH)
85100
os.remove(DOWNLOAD_FILE_PATH)
@@ -91,10 +106,34 @@ def mock_session_invalid():
91106
mock_response = Mock()
92107
mock_response.status_code = 200
93108
mock_response.raw = "not a bytes-like object"
94-
mock_session.request.return_value = mock_response
109+
mock_response.headers = {}
110+
111+
# Mock the methods our streaming code expects
112+
mock_response.read.return_value = b"test content"
113+
114+
def mock_iter_raw(chunk_size=None):
115+
# Return a generator that will fail during iteration
116+
# This simulates a case where the response object is invalid
117+
class BadIterator:
118+
def __iter__(self):
119+
return self
120+
121+
def __next__(self):
122+
# Simulate the error that would happen in real scenario
123+
raise AttributeError("'str' object has no attribute 'read'")
124+
125+
return BadIterator()
126+
127+
mock_response.iter_raw = mock_iter_raw
128+
129+
# Use Mock's context manager support
130+
mock_session.stream.return_value.__enter__.return_value = mock_response
131+
mock_session.stream.return_value.__exit__.return_value = None
132+
95133
yield mock_session
96-
assert os.path.exists(DOWNLOAD_FILE_PATH)
97-
os.remove(DOWNLOAD_FILE_PATH)
134+
# Don't assert file exists for invalid case since error should prevent creation
135+
if os.path.exists(DOWNLOAD_FILE_PATH):
136+
os.remove(DOWNLOAD_FILE_PATH)
98137

99138

100139
@pytest.mark.parametrize("method, params", TEST_FILE_CLIENT_METHODS.items())
@@ -200,7 +239,7 @@ def test_file_client_download_file(client, s3_presigned_url, mock_session):
200239
presigned_url=s3_presigned_url, file_path=DOWNLOAD_FILE_PATH
201240
)
202241
assert response == DOWNLOAD_FILE_PATH
203-
assert mock_session.request.call_count == 1
242+
assert mock_session.stream.call_count == 1
204243
# The file should exist after calling the method
205244
assert os.path.exists(DOWNLOAD_FILE_PATH)
206245
assert open(DOWNLOAD_FILE_PATH, "r").read() == "test data 12345.\n"

tests/unit/test_query_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,23 @@ def mock_session():
4444
mock_response = Mock()
4545
mock_response.status_code = 200
4646
mock_response.content = "test-content"
47+
mock_response.headers = {}
48+
4749
with open(QUERY_RESPONSES, "r", encoding="utf-8") as file:
4850
lines_from_file = [line.strip() for line in file.readlines()]
4951
mock_response.iter_lines.return_value = lines_from_file
52+
53+
# Mock the methods our streaming code expects
54+
file_content = "\n".join(lines_from_file)
55+
mock_response.read.return_value = file_content.encode("utf-8")
56+
mock_response.text = file_content
57+
58+
# Support both old request-style and new stream-style
5059
mock_session.request.return_value = mock_response
60+
61+
# Use Mock's context manager support for streaming
62+
mock_session.stream.return_value.__enter__.return_value = mock_response
63+
mock_session.stream.return_value.__exit__.return_value = None
5164
yield mock_session
5265

5366

@@ -90,8 +103,21 @@ def test_stream_get_raises_error(
90103
mock_response = Mock()
91104
mock_response.status_code = 200
92105
mock_response.content = "test-content"
106+
mock_response.headers = {}
93107
mock_response.iter_lines.return_value = test_response
108+
109+
# Mock the methods our streaming code expects
110+
file_content = "\n".join(test_response)
111+
mock_response.read.return_value = file_content.encode("utf-8")
112+
mock_response.text = file_content
113+
114+
# Support both old request-style and new stream-style
94115
mock_session.request.return_value = mock_response
116+
117+
# Use Mock's context manager support for streaming
118+
mock_session.stream.return_value.__enter__.return_value = mock_response
119+
mock_session.stream.return_value.__exit__.return_value = None
120+
95121
with pytest.raises(test_error) as err:
96122
client.queries.stream(request=query_request)
97123
assert error_msg in str(err.value)

0 commit comments

Comments
 (0)