Skip to content

Commit 12f6fbd

Browse files
authored
Add stream method for GCSRemoteIO (#59753)
* Add stream method for GCSRemoteLogIO
* Fix TestGCSTaskHandler, add error handling for read
* Add test_upload
* Open stream outside of _get_log_stream, early return for read if logs is None
* Add test_write and test_stream_and_read_methods
* Fix mistaken import of RawLogStream
* Fix mypy error; fix missing mock for get_credentials_and_project_id; fix mypy error; fix test
* Fix mypy and unit test; skip 2.11 test
* Fix compat test
* Fix review comment
1 parent 60b4ed4 commit 12f6fbd

File tree

2 files changed

+307
-24
lines changed

2 files changed

+307
-24
lines changed

providers/google/src/airflow/providers/google/cloud/log/gcs_task_handler.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@
4343
from airflow.utils.log.logging_mixin import LoggingMixin
4444

4545
if TYPE_CHECKING:
46+
from io import TextIOWrapper
47+
4648
from airflow.models.taskinstance import TaskInstance
4749
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
48-
from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
50+
from airflow.utils.log.file_task_handler import LogResponse, RawLogStream, StreamingLogResponse
4951

5052
_DEFAULT_SCOPESS = frozenset(
5153
[
@@ -149,11 +151,26 @@ def no_log_found(exc):
149151
exc, "resp", {}
150152
).get("status") == "404"
151153

152-
def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
153-
messages = []
154-
logs = []
154+
def read(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
    """
    Read the remote logs for a task instance fully into memory.

    Delegates to ``self.stream`` and drains every returned log stream into
    one string per stream.

    :param relative_path: Log path relative to ``self.remote_base``.
    :param ti: Task instance, forwarded unchanged to ``self.stream``.
    :return: ``(messages, logs)``. ``logs`` is ``None`` when ``stream``
        returned no log streams; otherwise it holds one string per stream
        that was read successfully, in the order ``stream`` returned them.
    """
    messages, log_streams = self.stream(relative_path, ti)
    if not log_streams:
        return messages, None

    logs: list[str] = []
    try:
        # Drain one stream at a time (instead of a single list
        # comprehension) so that logs which were already read are kept
        # when a later stream fails part-way through.
        for log_stream in log_streams:
            logs.append("".join(log_stream))
    except Exception as e:
        # Best effort: keep whatever was read. The error is only surfaced
        # as a message on Airflow versions before 3.0.
        if not AIRFLOW_V_3_0_PLUS:
            messages.append(f"Unable to read remote log {e}")

    return messages, logs
168+
169+
def stream(self, relative_path: str, ti: RuntimeTI) -> StreamingLogResponse:
170+
messages: list[str] = []
171+
log_streams: list[RawLogStream] = []
155172
remote_loc = os.path.join(self.remote_base, relative_path)
156-
uris = []
173+
uris: list[str] = []
157174
bucket, prefix = _parse_gcs_url(remote_loc)
158175
blobs = list(self.client.list_blobs(bucket_or_name=bucket, prefix=prefix))
159176

@@ -164,18 +181,29 @@ def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMes
164181
else:
165182
messages.extend(["Found remote logs:", *[f" * {x}" for x in sorted(uris)]])
166183
else:
167-
return messages, None
184+
return messages, []
168185

169186
try:
170187
for key in sorted(uris):
171188
blob = storage.Blob.from_string(key, self.client)
172-
remote_log = blob.download_as_bytes().decode()
173-
if remote_log:
174-
logs.append(remote_log)
189+
stream = blob.open("r")
190+
log_streams.append(self._get_log_stream(stream))
175191
except Exception as e:
176192
if not AIRFLOW_V_3_0_PLUS:
177193
messages.append(f"Unable to read remote log {e}")
178-
return messages, logs
194+
return messages, log_streams
195+
196+
def _get_log_stream(self, stream: TextIOWrapper) -> RawLogStream:
197+
"""
198+
Yield lines from the given stream.
199+
200+
:param stream: The opened stream to read from.
201+
:yield: Lines of the log file.
202+
"""
203+
try:
204+
yield from stream
205+
finally:
206+
stream.close()
179207

180208

181209
class GCSTaskHandler(FileTaskHandler, LoggingMixin):
@@ -273,7 +301,7 @@ def close(self):
273301
# Mark closed so we don't double write if close is called twice
274302
self.closed = True
275303

276-
def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]:
304+
def _read_remote_logs(self, ti, try_number, metadata=None) -> LogResponse:
277305
# Explicitly getting log relative path is necessary as the given
278306
# task instance might be different than task instance passed in
279307
# in set_context method.
@@ -283,7 +311,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInf
283311

284312
if logs is None:
285313
logs = []
286-
if not AIRFLOW_V_3_0_PLUS:
314+
if not AIRFLOW_V_3_0_PLUS and not messages:
287315
messages.append(f"No logs found in GCS; ti={ti}")
288316

289317
return messages, logs

0 commit comments

Comments
 (0)