Skip to content

Commit fa7734f

Browse files
committed
Enforce strict data types for record.py. Bump default worker init timeout from 40 to 60.
1 parent daf0794 commit fa7734f

File tree

8 files changed

+53
-56
lines changed

8 files changed

+53
-56
lines changed

scripts/zmq_pubsub_demo.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,11 @@ async def publish_test_events(publisher) -> None:
153153
event_type=SampleEventType.RECV_FIRST,
154154
timestamp_ns=10010,
155155
sample_uuid=uuid1,
156-
data={"ttft_ms": 10.0},
157156
),
158157
EventRecord(
159158
event_type=SampleEventType.RECV_FIRST,
160159
timestamp_ns=10190,
161160
sample_uuid=uuid2,
162-
data={"ttft_ms": 187.0},
163161
),
164162
EventRecord(
165163
event_type=SampleEventType.RECV_NON_FIRST,
@@ -190,7 +188,7 @@ async def publish_test_events(publisher) -> None:
190188
event_type=SampleEventType.COMPLETE,
191189
timestamp_ns=10211,
192190
sample_uuid=uuid1,
193-
data={"tokens": 50},
191+
data="Hello world",
194192
),
195193
EventRecord(
196194
event_type=SampleEventType.RECV_NON_FIRST,
@@ -211,7 +209,7 @@ async def publish_test_events(publisher) -> None:
211209
event_type=SampleEventType.COMPLETE,
212210
timestamp_ns=10219,
213211
sample_uuid=uuid2,
214-
data={"tokens": 75},
212+
data="Sample output for uuid2",
215213
),
216214
]
217215

src/inference_endpoint/async_utils/services/event_logger/sql_writer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pathlib import Path
1919

20-
import orjson
20+
import msgspec
2121
from inference_endpoint.core.record import EventRecord
2222
from sqlalchemy import BigInteger, Integer, LargeBinary, String, create_engine
2323
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
@@ -55,17 +55,13 @@ class EventRowModel(Base):
5555

5656

5757
def _record_to_row(record: EventRecord) -> EventRowModel:
58-
"""Convert an EventRecord to an EventRowModel using EventType topic strings."""
59-
data_bytes = b""
60-
if record.data:
61-
data_bytes = orjson.dumps(record.data)
6258
# event_type.topic is set by EventTypeMeta on each enum member
6359
topic = record.event_type.topic # type: ignore[attr-defined]
6460
return EventRowModel(
6561
sample_uuid=record.sample_uuid,
6662
event_type=topic,
6763
timestamp_ns=record.timestamp_ns,
68-
data=data_bytes,
64+
data=msgspec.json.encode(record.data),
6965
)
7066

7167

src/inference_endpoint/async_utils/transport/protocol.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
decode_event_record,
3636
encode_event_record,
3737
)
38-
from inference_endpoint.core.types import Query, QueryResult, StreamChunk
38+
from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk
3939

4040
if TYPE_CHECKING:
4141
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
@@ -341,14 +341,12 @@ def _on_readable(self) -> None:
341341
try:
342342
event_record = decode_event_record(payload)
343343
except msgspec.DecodeError as e:
344-
# Record an error instead
345-
# TODO: Make `data` field more rigidly typed
346344
event_record = EventRecord(
347345
event_type=ErrorEventType.GENERIC,
348-
data={
349-
"error_type": "msgspec.DecodeError",
350-
"error_message": str(e),
351-
},
346+
data=ErrorData(
347+
error_type="msgspec.DecodeError",
348+
error_message=str(e),
349+
),
352350
)
353351
records.append(event_record)
354352
except StopIteration:

src/inference_endpoint/core/record.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import msgspec
2121

22+
from .types import OUTPUT_TYPE, ErrorData
23+
2224
TOPIC_FRAME_SIZE: Final[int] = 40
2325
"""int: Fixed bytesize for the encoded topic string. PUB messages will be prefixed by a
2426
topic string corresponding to the EventType. This topic will be null-padded to this fixed
@@ -151,7 +153,7 @@ class EventRecord(msgspec.Struct, kw_only=True): # type: ignore[call-arg]
151153
event_type: EventType
152154
timestamp_ns: int = msgspec.field(default_factory=time.monotonic_ns)
153155
sample_uuid: str = ""
154-
data: dict[str, Any] = msgspec.field(default_factory=dict)
156+
data: OUTPUT_TYPE | ErrorData | None = None
155157

156158

157159
_ENCODER = msgspec.msgpack.Encoder(enc_hook=EventType.encode_hook)

src/inference_endpoint/core/types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ class QueryStatus(Enum):
5151
OUTPUT_ELEM_TYPE = str | tuple[str, ...]
5252
"""Type for a single output or reasoning value: string (non-streaming) or tuple of strings (streaming)."""
5353

54-
_OUTPUT_DICT_TYPE = dict[str, str | list[str]]
55-
_OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None
56-
5754

5855
class TextModelOutput(msgspec.Struct, tag=True, kw_only=True): # type: ignore[call-arg]
5956
"""Structured output from a text model.

src/inference_endpoint/endpoint_client/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class HTTPClientConfig:
7171
cpu_affinity: AffinityPlan | None = None
7272

7373
# Worker lifecycle timeouts
74-
worker_initialization_timeout: float = 40.0 # init
74+
worker_initialization_timeout: float = 60.0 # init
7575
worker_graceful_shutdown_wait: float = 0.5 # post-run
7676
worker_force_kill_timeout: float = 0.5 # post-run
7777

tests/unit/async_utils/test_event_publisher.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ async def test_publish_sends_data_on_ipc_socket(
151151
record = EventRecord(
152152
event_type=SessionEventType.STARTED,
153153
sample_uuid="",
154-
data={"manual_socket": True},
154+
data="manual_socket_test",
155155
)
156156
event_publisher_service.publish(record)
157157
# Yield so the publisher's event loop can drain the send buffer
@@ -168,7 +168,7 @@ async def test_publish_sends_data_on_ipc_socket(
168168
assert topic_bytes == b"session.started"
169169
rec = decode_event_record(bytes(payload))
170170
assert rec.event_type.value == SessionEventType.STARTED.value
171-
assert rec.data == {"manual_socket": True}
171+
assert rec.data == "manual_socket_test"
172172
# Socket is closed by ManagedZMQContext.cleanup() in ev_pub_zmq_context fixture teardown.
173173

174174
def test_singleton_returns_same_instance(
@@ -190,15 +190,15 @@ async def test_publish_session_event_received_by_subscriber(
190190
record = EventRecord(
191191
event_type=SessionEventType.STARTED,
192192
sample_uuid="",
193-
data={"key": "value"},
193+
data="value",
194194
)
195195
event_publisher_service.publish(record)
196196
await asyncio.sleep(0.05) # Let publisher drain send buffer
197197
await asyncio.wait_for(received_event.wait(), timeout=_WAIT_RECORDS_TIMEOUT)
198198
assert len(collecting_subscriber.received) == 1
199199
rec = collecting_subscriber.received[0]
200200
assert rec.event_type.value == SessionEventType.STARTED.value
201-
assert rec.data == {"key": "value"}
201+
assert rec.data == "value"
202202

203203
@pytest.mark.asyncio
204204
async def test_publish_sample_event_received_by_subscriber(
@@ -210,7 +210,7 @@ async def test_publish_sample_event_received_by_subscriber(
210210
record = EventRecord(
211211
event_type=SampleEventType.COMPLETE,
212212
sample_uuid="sample-1",
213-
data={"latency_ns": 42},
213+
data="sample output",
214214
)
215215
event_publisher_service.publish(record)
216216
await asyncio.sleep(0.05) # Let publisher drain send buffer
@@ -219,7 +219,7 @@ async def test_publish_sample_event_received_by_subscriber(
219219
rec = collecting_subscriber.received[0]
220220
assert rec.event_type.value == SampleEventType.COMPLETE.value
221221
assert rec.sample_uuid == "sample-1"
222-
assert rec.data == {"latency_ns": 42}
222+
assert rec.data == "sample output"
223223

224224
@pytest.mark.asyncio
225225
async def test_multiple_events_received_in_order(
@@ -232,12 +232,11 @@ async def test_multiple_events_received_in_order(
232232
record = EventRecord(
233233
event_type=SampleEventType.ISSUED,
234234
sample_uuid=f"sample-{i}",
235-
data={"seq": i},
236235
)
237236
event_publisher_service.publish(record)
238237
await asyncio.sleep(0.05) # Small delay between events (demo-style)
239238
await asyncio.wait_for(received_event.wait(), timeout=_WAIT_RECORDS_TIMEOUT)
240239
assert len(collecting_subscriber.received) == 3
241240
for i in range(3):
242241
assert collecting_subscriber.received[i].sample_uuid == f"sample-{i}"
243-
assert collecting_subscriber.received[i].data.get("seq") == i
242+
assert collecting_subscriber.received[i].data is None

tests/unit/core/test_record.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import time
1919

20-
import msgspec
2120
import pytest
2221
from inference_endpoint.core.record import (
2322
TOPIC_FRAME_SIZE,
@@ -29,6 +28,7 @@
2928
decode_event_record,
3029
encode_event_record,
3130
)
31+
from inference_endpoint.core.types import ErrorData, TextModelOutput
3232

3333

3434
class TestEventType:
@@ -64,8 +64,7 @@ def test_construction_with_only_event_type_uses_defaults(self):
6464
after = time.monotonic_ns()
6565
assert before <= record.timestamp_ns <= after
6666
assert record.sample_uuid == ""
67-
assert record.data == {}
68-
assert isinstance(record.data, dict)
67+
assert record.data is None
6968

7069

7170
class TestEncodeEventRecord:
@@ -76,17 +75,16 @@ def test_returns_tuple_of_topic_bytes_padded_and_payload_bytes_with_valid_msgpac
7675
record = EventRecord(
7776
event_type=SampleEventType.ISSUED,
7877
sample_uuid="test-uuid",
79-
data={"key": "value"},
78+
data="test-output",
8079
)
8180
topic_bytes, payload = encode_event_record(record)
8281
assert isinstance(topic_bytes, bytes)
8382
assert len(topic_bytes) == TOPIC_FRAME_SIZE
8483
assert topic_bytes.rstrip(b"\x00") == b"sample.issued"
8584
assert isinstance(payload, bytes)
86-
decoded = msgspec.msgpack.decode(payload)
87-
assert isinstance(decoded, dict)
88-
assert decoded.get("sample_uuid") == "test-uuid"
89-
assert decoded.get("data") == {"key": "value"}
85+
decoded = decode_event_record(payload)
86+
assert decoded.sample_uuid == "test-uuid"
87+
assert decoded.data == "test-output"
9088

9189
def test_topic_bytes_padded_matches_event_type_for_session_sample_error(self):
9290
"""Topic is null-padded to TOPIC_FRAME_SIZE for single-frame ZMQ sends."""
@@ -106,37 +104,55 @@ def test_session_event_round_trips_with_all_fields(self):
106104
record = EventRecord(
107105
event_type=SessionEventType.STARTED,
108106
sample_uuid="sess-1",
109-
data={"session_id": "abc"},
110107
)
111108
_, payload = encode_event_record(record)
112109
decoded = decode_event_record(payload)
113110
assert decoded.event_type.topic == SessionEventType.STARTED.topic
114111
assert decoded.sample_uuid == "sess-1"
115-
assert decoded.data == {"session_id": "abc"}
112+
assert decoded.data is None
116113
assert isinstance(decoded.timestamp_ns, int)
117114
assert decoded.timestamp_ns == record.timestamp_ns
118115

119-
def test_sample_event_round_trips(self):
116+
def test_sample_event_round_trips_with_output(self):
120117
record = EventRecord(
121118
event_type=SampleEventType.COMPLETE,
122119
sample_uuid="sample-42",
123-
data={"latency_ns": 1000},
120+
data="output text",
124121
)
125122
_, payload = encode_event_record(record)
126123
decoded = decode_event_record(payload)
127124
assert decoded.event_type.topic == SampleEventType.COMPLETE.topic
128125
assert decoded.sample_uuid == "sample-42"
129-
assert decoded.data == {"latency_ns": 1000}
126+
assert decoded.data == "output text"
130127

131-
def test_error_event_round_trips_with_defaults(self):
128+
def test_sample_event_round_trips_with_text_model_output(self):
129+
record = EventRecord(
130+
event_type=SampleEventType.COMPLETE,
131+
sample_uuid="sample-42",
132+
data=TextModelOutput(output="out", reasoning="reason"),
133+
)
134+
_, payload = encode_event_record(record)
135+
decoded = decode_event_record(payload)
136+
assert decoded.event_type.topic == SampleEventType.COMPLETE.topic
137+
assert decoded.sample_uuid == "sample-42"
138+
assert isinstance(decoded.data, TextModelOutput)
139+
assert decoded.data.output == "out"
140+
assert decoded.data.reasoning == "reason"
141+
142+
def test_error_event_round_trips_with_error_data(self):
132143
record = EventRecord(
133144
event_type=ErrorEventType.LOADGEN,
134-
data={"message": "error details"},
145+
data=ErrorData(
146+
error_type="LoadgenError",
147+
error_message="error details",
148+
),
135149
)
136150
_, payload = encode_event_record(record)
137151
decoded = decode_event_record(payload)
138152
assert decoded.event_type.topic == ErrorEventType.LOADGEN.topic
139-
assert decoded.data == {"message": "error details"}
153+
assert isinstance(decoded.data, ErrorData)
154+
assert decoded.data.error_type == "LoadgenError"
155+
assert decoded.data.error_message == "error details"
140156
assert decoded.sample_uuid == ""
141157

142158
def test_record_with_only_event_type_round_trips_with_defaults(self):
@@ -145,7 +161,7 @@ def test_record_with_only_event_type_round_trips_with_defaults(self):
145161
decoded = decode_event_record(payload)
146162
assert decoded.event_type.topic == SessionEventType.ENDED.topic
147163
assert decoded.sample_uuid == ""
148-
assert decoded.data == {}
164+
assert decoded.data is None
149165
assert decoded.timestamp_ns > 0
150166

151167
def test_explicit_timestamp_ns_preserved_round_trip(self):
@@ -157,12 +173,3 @@ def test_explicit_timestamp_ns_preserved_round_trip(self):
157173
_, payload = encode_event_record(record)
158174
decoded = decode_event_record(payload)
159175
assert decoded.timestamp_ns == ts
160-
161-
def test_nested_and_list_data_round_trips(self):
162-
record = EventRecord(
163-
event_type=SampleEventType.TRANSPORT_RECV,
164-
data={"nested": {"a": 1}, "list": [1, 2, 3]},
165-
)
166-
_, payload = encode_event_record(record)
167-
decoded = decode_event_record(payload)
168-
assert decoded.data == {"nested": {"a": 1}, "list": [1, 2, 3]}

0 commit comments

Comments
 (0)