Skip to content

Commit b9620bf

Browse files
authored
Fix reading huge (XCom) response in TaskSDK task process (#53186)
If you tried to send a large XCom value, it would fail on the task/child process side with this error: > RuntimeError: unable to read full response in child. (We read 36476, but expected 1310046) (The exact number of bytes that could be read depended on many different factors, such as the OS, the current state of the socket, and other things. Sometimes it would read up to 256kb fine, other times only 35kb as here.) This is because the kernel-level read-side socket buffer was full, so that was as much as the Supervisor could send at once. The fix is to read in a loop until we get it all.
1 parent 47bbe55 commit b9620bf

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,16 @@ def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[i
234234
length = int.from_bytes(len_bytes, byteorder="big")
235235

236236
buffer = bytearray(length)
237-
nread = self.socket.recv_into(buffer)
238-
if nread != length:
239-
raise RuntimeError(
240-
f"unable to read full response in child. (We read {nread}, but expected {length})"
241-
)
242-
if nread == 0:
243-
raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
244-
245-
resp = self.resp_decoder.decode(buffer)
237+
mv = memoryview(buffer)
238+
239+
pos = 0
240+
while pos < length:
241+
nread = self.socket.recv_into(mv[pos:])
242+
if nread == 0:
243+
raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
244+
pos += nread
245+
246+
resp = self.resp_decoder.decode(mv)
246247
if maxfds:
247248
return resp, fds or []
248249
return resp

task-sdk/tests/task_sdk/execution_time/test_comms.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
import threading
2021
import uuid
2122
from socket import socketpair
2223

@@ -82,3 +83,32 @@ def test_recv_StartupDetails(self):
8283
assert msg.dag_rel_path == "/dev/null"
8384
assert msg.bundle_info == BundleInfo(name="any-name", version="any-version")
8485
assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
86+
87+
def test_huge_payload(self):
88+
r, w = socketpair()
89+
90+
msg = {
91+
"type": "XComResult",
92+
"key": "a",
93+
"value": ("a" * 10 * 1024 * 1024) + "b", # A 10mb xcom value
94+
}
95+
96+
w.settimeout(1.0)
97+
bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None))
98+
99+
# Since `sendall` blocks, we need to do the send in another thread, so we can perform the read here
100+
t = threading.Thread(target=w.sendall, args=(len(bytes).to_bytes(4, byteorder="big") + bytes,))
101+
t.start()
102+
103+
decoder = CommsDecoder(socket=r, log=None)
104+
105+
try:
106+
msg = decoder._get_response()
107+
finally:
108+
t.join(2)
109+
110+
assert msg is not None
111+
112+
# It actually failed to read at all for large values, but let's just make sure we get it all
113+
assert len(msg.value) == 10 * 1024 * 1024 + 1
114+
assert msg.value[-1] == "b"

0 commit comments

Comments
 (0)