Skip to content

Commit 82d2354

Browse files
committed
[changes] Fixed AtlanClient event stream handling
1 parent 1a0f464 commit 82d2354

File tree

4 files changed

+150
-42
lines changed

4 files changed

+150
-42
lines changed

pyatlan/client/asset.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,20 @@
2525
)
2626
from warnings import warn
2727

28-
import httpx
2928
from pydantic.v1 import (
3029
StrictStr,
3130
ValidationError,
3231
constr,
3332
parse_obj_as,
3433
validate_arguments,
3534
)
36-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
35+
from tenacity import (
36+
RetryError,
37+
retry,
38+
retry_if_exception_type,
39+
stop_after_attempt,
40+
wait_fixed,
41+
)
3742

3843
from pyatlan.client.common import ApiCaller
3944
from pyatlan.client.constants import (
@@ -844,7 +849,10 @@ def delete_by_guid(self, guid: Union[str, List[str]]) -> AssetMutationResponse:
844849
)
845850
response = AssetMutationResponse(**raw_json)
846851
for asset in response.assets_deleted(asset_type=Asset):
847-
self._wait_till_deleted(asset)
852+
try:
853+
self._wait_till_deleted(asset)
854+
except RetryError as err:
855+
raise ErrorCode.RETRY_OVERRUN.exception_with_parameters() from err
848856
return response
849857

850858
@retry(
@@ -854,12 +862,9 @@ def delete_by_guid(self, guid: Union[str, List[str]]) -> AssetMutationResponse:
854862
wait=wait_fixed(1),
855863
)
856864
def _wait_till_deleted(self, asset: Asset):
857-
try:
858-
asset = self.retrieve_minimal(guid=asset.guid, asset_type=Asset)
859-
if asset.status == EntityStatus.DELETED:
860-
return
861-
except httpx.TransportError as err:
862-
raise ErrorCode.RETRY_OVERRUN.exception_with_parameters() from err
865+
asset = self.retrieve_minimal(guid=asset.guid, asset_type=Asset)
866+
if asset.status == EntityStatus.DELETED:
867+
return
863868

864869
@validate_arguments
865870
def restore(self, asset_type: Type[A], qualified_name: str) -> bool:

pyatlan/client/atlan.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import copy
88
import json
99
import logging
10-
import shutil
1110
import uuid
1211
from contextvars import ContextVar
1312
from http import HTTPStatus
@@ -121,14 +120,27 @@ def log_response(response, *args, **kwargs):
121120
LOGGER.debug("URL: %s", response.request.url)
122121

123122

123+
def get_session():
124+
return httpx.Client(
125+
transport=RetryTransport(retry=DEFAULT_RETRY),
126+
headers={
127+
"x-atlan-agent": "sdk",
128+
"x-atlan-agent-id": "python",
129+
"x-atlan-client-origin": "product_sdk",
130+
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
131+
},
132+
event_hooks={"response": [log_response]},
133+
)
134+
135+
124136
class AtlanClient(BaseSettings):
125137
base_url: Union[Literal["INTERNAL"], HttpUrl]
126138
api_key: str
127139
connect_timeout: float = 30.0 # 30 secs
128140
read_timeout: float = 900.0 # 15 mins
129141
retry: Retry = DEFAULT_RETRY
130142
_401_has_retried: ContextVar[bool] = ContextVar("_401_has_retried", default=False)
131-
_session: httpx.Client = PrivateAttr(default_factory=lambda: httpx.Client())
143+
_session: httpx.Client = PrivateAttr(default_factory=get_session)
132144
_request_params: dict = PrivateAttr()
133145
_user_id: Optional[str] = PrivateAttr(default=None)
134146
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
@@ -168,17 +180,6 @@ def __init__(self, **data):
168180
"authorization": f"Bearer {self.api_key}",
169181
}
170182
}
171-
# Configure httpx client with retry transport
172-
self._session = httpx.Client(
173-
transport=RetryTransport(retry=self.retry),
174-
headers={
175-
"x-atlan-agent": "sdk",
176-
"x-atlan-agent-id": "python",
177-
"x-atlan-client-origin": "product_sdk",
178-
"User-Agent": f"Atlan-PythonSDK/{VERSION}",
179-
},
180-
event_hooks={"response": [log_response]},
181-
)
182183
self._401_has_retried.set(False)
183184

184185
@property
@@ -342,8 +343,9 @@ def update_headers(self, header: Dict[str, str]):
342343

343344
def _handle_file_download(self, raw_response: Any, file_path: str) -> str:
344345
try:
345-
download_file = open(file_path, "wb")
346-
shutil.copyfileobj(raw_response, download_file)
346+
with open(file_path, "wb") as download_file:
347+
for chunk in raw_response:
348+
download_file.write(chunk)
347349
except Exception as err:
348350
raise ErrorCode.UNABLE_TO_DOWNLOAD_FILE.exception_with_parameters(
349351
str((hasattr(err, "strerror") and err.strerror) or err), file_path
@@ -374,15 +376,49 @@ def _call_api_internal(
374376
timeout=timeout,
375377
)
376378
elif api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
377-
response = self._session.request(
379+
with self._session.stream(
378380
api.method.value,
379381
path,
380382
**params,
381-
stream=True,
382383
timeout=timeout,
383-
)
384-
if download_file_path:
385-
return self._handle_file_download(response.raw, download_file_path)
384+
) as stream_response:
385+
if download_file_path:
386+
return self._handle_file_download(
387+
stream_response.iter_raw(), download_file_path
388+
)
389+
390+
# For event streams, we need to read the content while the stream is open
391+
# Store the response data and create a mock response object for common processing
392+
content = stream_response.read()
393+
text = content.decode("utf-8") if content else ""
394+
lines = []
395+
396+
# Only process lines for successful responses to avoid errors on error responses
397+
if stream_response.status_code == api.expected_status:
398+
# Split the already-decoded body text into individual lines
399+
lines = text.splitlines() if text else []
400+
401+
response_data = {
402+
"status_code": stream_response.status_code,
403+
"headers": stream_response.headers,
404+
"text": text,
405+
"content": content,
406+
"lines": lines,
407+
}
408+
409+
# Create a simple namespace object to mimic the response interface
410+
response = SimpleNamespace(
411+
status_code=response_data["status_code"],
412+
headers=response_data["headers"],
413+
text=response_data["text"],
414+
content=response_data["content"],
415+
_stream_lines=response_data[
416+
"lines"
417+
], # Store lines for event processing
418+
json=lambda: json.loads(response_data["text"])
419+
if response_data["text"]
420+
else {},
421+
)
386422
else:
387423
response = self._session.request(
388424
api.method.value,
@@ -429,14 +465,16 @@ def _call_api_internal(
429465
response,
430466
)
431467
if api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
432-
for line in response.iter_lines(decode_unicode=True):
433-
if not line:
434-
continue
435-
if not line.startswith("data: "):
436-
raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters(
437-
line
438-
)
439-
events.append(json.loads(line.split("data: ")[1]))
468+
# Process event stream using stored lines from the streaming response
469+
if hasattr(response, "_stream_lines"):
470+
for line in response._stream_lines:
471+
if not line:
472+
continue
473+
if not line.startswith("data: "):
474+
raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters(
475+
line
476+
)
477+
events.append(json.loads(line.split("data: ")[1]))
440478
if text_response:
441479
response_ = response.text
442480
else:

tests/unit/test_file_client.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,22 @@ def mock_session():
7979
mock_response = Mock()
8080
mock_response.status_code = 200
8181
mock_response.raw = open(UPLOAD_FILE_PATH, "rb")
82-
mock_session.request.return_value = mock_response
82+
mock_response.headers = {}
83+
84+
# Mock the methods our streaming code expects
85+
mock_response.read.return_value = b"test content"
86+
87+
def mock_iter_raw(chunk_size=None):
88+
# Use the actual expected content from upload.txt
89+
content = b"test data 12345.\n"
90+
yield content
91+
92+
mock_response.iter_raw = mock_iter_raw
93+
94+
# Use Mock's context manager support
95+
mock_session.stream.return_value.__enter__.return_value = mock_response
96+
mock_session.stream.return_value.__exit__.return_value = None
97+
8398
yield mock_session
8499
assert os.path.exists(DOWNLOAD_FILE_PATH)
85100
os.remove(DOWNLOAD_FILE_PATH)
@@ -91,10 +106,34 @@ def mock_session_invalid():
91106
mock_response = Mock()
92107
mock_response.status_code = 200
93108
mock_response.raw = "not a bytes-like object"
94-
mock_session.request.return_value = mock_response
109+
mock_response.headers = {}
110+
111+
# Mock the methods our streaming code expects
112+
mock_response.read.return_value = b"test content"
113+
114+
def mock_iter_raw(chunk_size=None):
115+
# Return a generator that will fail during iteration
116+
# This simulates a case where the response object is invalid
117+
class BadIterator:
118+
def __iter__(self):
119+
return self
120+
121+
def __next__(self):
122+
# Simulate the error that would happen in real scenario
123+
raise AttributeError("'str' object has no attribute 'read'")
124+
125+
return BadIterator()
126+
127+
mock_response.iter_raw = mock_iter_raw
128+
129+
# Use Mock's context manager support
130+
mock_session.stream.return_value.__enter__.return_value = mock_response
131+
mock_session.stream.return_value.__exit__.return_value = None
132+
95133
yield mock_session
96-
assert os.path.exists(DOWNLOAD_FILE_PATH)
97-
os.remove(DOWNLOAD_FILE_PATH)
134+
# Don't assert that the file exists in the invalid case; just clean it up if it was created
135+
if os.path.exists(DOWNLOAD_FILE_PATH):
136+
os.remove(DOWNLOAD_FILE_PATH)
98137

99138

100139
@pytest.mark.parametrize("method, params", TEST_FILE_CLIENT_METHODS.items())
@@ -200,7 +239,7 @@ def test_file_client_download_file(client, s3_presigned_url, mock_session):
200239
presigned_url=s3_presigned_url, file_path=DOWNLOAD_FILE_PATH
201240
)
202241
assert response == DOWNLOAD_FILE_PATH
203-
assert mock_session.request.call_count == 1
242+
assert mock_session.stream.call_count == 1
204243
# The file should exist after calling the method
205244
assert os.path.exists(DOWNLOAD_FILE_PATH)
206245
assert open(DOWNLOAD_FILE_PATH, "r").read() == "test data 12345.\n"

tests/unit/test_query_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,23 @@ def mock_session():
4444
mock_response = Mock()
4545
mock_response.status_code = 200
4646
mock_response.content = "test-content"
47+
mock_response.headers = {}
48+
4749
with open(QUERY_RESPONSES, "r", encoding="utf-8") as file:
4850
lines_from_file = [line.strip() for line in file.readlines()]
4951
mock_response.iter_lines.return_value = lines_from_file
52+
53+
# Mock the methods our streaming code expects
54+
file_content = "\n".join(lines_from_file)
55+
mock_response.read.return_value = file_content.encode("utf-8")
56+
mock_response.text = file_content
57+
58+
# Support both old request-style and new stream-style
5059
mock_session.request.return_value = mock_response
60+
61+
# Use Mock's context manager support for streaming
62+
mock_session.stream.return_value.__enter__.return_value = mock_response
63+
mock_session.stream.return_value.__exit__.return_value = None
5164
yield mock_session
5265

5366

@@ -90,8 +103,21 @@ def test_stream_get_raises_error(
90103
mock_response = Mock()
91104
mock_response.status_code = 200
92105
mock_response.content = "test-content"
106+
mock_response.headers = {}
93107
mock_response.iter_lines.return_value = test_response
108+
109+
# Mock the methods our streaming code expects
110+
file_content = "\n".join(test_response)
111+
mock_response.read.return_value = file_content.encode("utf-8")
112+
mock_response.text = file_content
113+
114+
# Support both old request-style and new stream-style
94115
mock_session.request.return_value = mock_response
116+
117+
# Use Mock's context manager support for streaming
118+
mock_session.stream.return_value.__enter__.return_value = mock_response
119+
mock_session.stream.return_value.__exit__.return_value = None
120+
95121
with pytest.raises(test_error) as err:
96122
client.queries.stream(request=query_request)
97123
assert error_msg in str(err.value)

0 commit comments

Comments
 (0)