|
7 | 7 | import copy
|
8 | 8 | import json
|
9 | 9 | import logging
|
10 |
| -import shutil |
11 | 10 | import uuid
|
12 | 11 | from contextvars import ContextVar
|
13 | 12 | from http import HTTPStatus
|
@@ -121,14 +120,27 @@ def log_response(response, *args, **kwargs):
|
121 | 120 | LOGGER.debug("URL: %s", response.request.url)
|
122 | 121 |
|
123 | 122 |
|
def get_session(retry=DEFAULT_RETRY):
    """Build an ``httpx.Client`` pre-configured for talking to the Atlan API.

    The client routes every request through a ``RetryTransport`` and tags it
    with the Atlan SDK identification headers; every response is passed to
    ``log_response`` via the ``event_hooks`` mechanism.

    :param retry: retry policy handed to ``RetryTransport``. Defaults to
        ``DEFAULT_RETRY`` so this function can be used as a zero-argument
        ``default_factory``; pass ``AtlanClient.retry`` (or any custom
        ``Retry``) to honor a caller-configured policy instead of silently
        ignoring it.
    :returns: a configured ``httpx.Client`` instance.
    """
    return httpx.Client(
        transport=RetryTransport(retry=retry),
        headers={
            "x-atlan-agent": "sdk",
            "x-atlan-agent-id": "python",
            "x-atlan-client-origin": "product_sdk",
            "User-Agent": f"Atlan-PythonSDK/{VERSION}",
        },
        event_hooks={"response": [log_response]},
    )
| 134 | + |
| 135 | + |
124 | 136 | class AtlanClient(BaseSettings):
|
125 | 137 | base_url: Union[Literal["INTERNAL"], HttpUrl]
|
126 | 138 | api_key: str
|
127 | 139 | connect_timeout: float = 30.0 # 30 secs
|
128 | 140 | read_timeout: float = 900.0 # 15 mins
|
129 | 141 | retry: Retry = DEFAULT_RETRY
|
130 | 142 | _401_has_retried: ContextVar[bool] = ContextVar("_401_has_retried", default=False)
|
131 |
| - _session: httpx.Client = PrivateAttr(default_factory=lambda: httpx.Client()) |
| 143 | + _session: httpx.Client = PrivateAttr(default_factory=get_session) |
132 | 144 | _request_params: dict = PrivateAttr()
|
133 | 145 | _user_id: Optional[str] = PrivateAttr(default=None)
|
134 | 146 | _workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
|
@@ -168,17 +180,6 @@ def __init__(self, **data):
|
168 | 180 | "authorization": f"Bearer {self.api_key}",
|
169 | 181 | }
|
170 | 182 | }
|
171 |
| - # Configure httpx client with retry transport |
172 |
| - self._session = httpx.Client( |
173 |
| - transport=RetryTransport(retry=self.retry), |
174 |
| - headers={ |
175 |
| - "x-atlan-agent": "sdk", |
176 |
| - "x-atlan-agent-id": "python", |
177 |
| - "x-atlan-client-origin": "product_sdk", |
178 |
| - "User-Agent": f"Atlan-PythonSDK/{VERSION}", |
179 |
| - }, |
180 |
| - event_hooks={"response": [log_response]}, |
181 |
| - ) |
182 | 183 | self._401_has_retried.set(False)
|
183 | 184 |
|
184 | 185 | @property
|
@@ -342,8 +343,9 @@ def update_headers(self, header: Dict[str, str]):
|
342 | 343 |
|
343 | 344 | def _handle_file_download(self, raw_response: Any, file_path: str) -> str:
|
344 | 345 | try:
|
345 |
| - download_file = open(file_path, "wb") |
346 |
| - shutil.copyfileobj(raw_response, download_file) |
| 346 | + with open(file_path, "wb") as download_file: |
| 347 | + for chunk in raw_response: |
| 348 | + download_file.write(chunk) |
347 | 349 | except Exception as err:
|
348 | 350 | raise ErrorCode.UNABLE_TO_DOWNLOAD_FILE.exception_with_parameters(
|
349 | 351 | str((hasattr(err, "strerror") and err.strerror) or err), file_path
|
@@ -374,15 +376,49 @@ def _call_api_internal(
|
374 | 376 | timeout=timeout,
|
375 | 377 | )
|
376 | 378 | elif api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
|
377 |
| - response = self._session.request( |
| 379 | + with self._session.stream( |
378 | 380 | api.method.value,
|
379 | 381 | path,
|
380 | 382 | **params,
|
381 |
| - stream=True, |
382 | 383 | timeout=timeout,
|
383 |
| - ) |
384 |
| - if download_file_path: |
385 |
| - return self._handle_file_download(response.raw, download_file_path) |
| 384 | + ) as stream_response: |
| 385 | + if download_file_path: |
| 386 | + return self._handle_file_download( |
| 387 | + stream_response.iter_raw(), download_file_path |
| 388 | + ) |
| 389 | + |
| 390 | + # For event streams, we need to read the content while the stream is open |
| 391 | + # Store the response data and create a mock response object for common processing |
| 392 | + content = stream_response.read() |
| 393 | + text = content.decode("utf-8") if content else "" |
| 394 | + lines = [] |
| 395 | + |
| 396 | + # Only process lines for successful responses to avoid errors on error responses |
| 397 | + if stream_response.status_code == api.expected_status: |
| 398 | + # Reset stream position and get lines |
| 399 | + lines = text.splitlines() if text else [] |
| 400 | + |
| 401 | + response_data = { |
| 402 | + "status_code": stream_response.status_code, |
| 403 | + "headers": stream_response.headers, |
| 404 | + "text": text, |
| 405 | + "content": content, |
| 406 | + "lines": lines, |
| 407 | + } |
| 408 | + |
| 409 | + # Create a simple namespace object to mimic the response interface |
| 410 | + response = SimpleNamespace( |
| 411 | + status_code=response_data["status_code"], |
| 412 | + headers=response_data["headers"], |
| 413 | + text=response_data["text"], |
| 414 | + content=response_data["content"], |
| 415 | + _stream_lines=response_data[ |
| 416 | + "lines" |
| 417 | + ], # Store lines for event processing |
| 418 | + json=lambda: json.loads(response_data["text"]) |
| 419 | + if response_data["text"] |
| 420 | + else {}, |
| 421 | + ) |
386 | 422 | else:
|
387 | 423 | response = self._session.request(
|
388 | 424 | api.method.value,
|
@@ -429,14 +465,16 @@ def _call_api_internal(
|
429 | 465 | response,
|
430 | 466 | )
|
431 | 467 | if api.consumes == EVENT_STREAM and api.produces == EVENT_STREAM:
|
432 |
| - for line in response.iter_lines(decode_unicode=True): |
433 |
| - if not line: |
434 |
| - continue |
435 |
| - if not line.startswith("data: "): |
436 |
| - raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters( |
437 |
| - line |
438 |
| - ) |
439 |
| - events.append(json.loads(line.split("data: ")[1])) |
| 468 | + # Process event stream using stored lines from the streaming response |
| 469 | + if hasattr(response, "_stream_lines"): |
| 470 | + for line in response._stream_lines: |
| 471 | + if not line: |
| 472 | + continue |
| 473 | + if not line.startswith("data: "): |
| 474 | + raise ErrorCode.UNABLE_TO_DESERIALIZE.exception_with_parameters( |
| 475 | + line |
| 476 | + ) |
| 477 | + events.append(json.loads(line.split("data: ")[1])) |
440 | 478 | if text_response:
|
441 | 479 | response_ = response.text
|
442 | 480 | else:
|
|
0 commit comments