Skip to content

Commit b723356

Browse files
committed
test: refactor tests for readability and maintainability
- Convert repetitive test methods to parametrized tests
- Split test_reporter.py into focused modules: test_reporter_report.py, test_reporter_rollup.py, test_reporter_stop_perf.py
- Consolidate duplicate test patterns across test suites
- Improve test organization in the transforms, extractor, types, and record tests
1 parent f2ec2f8 commit b723356

26 files changed

+1721
-2180
lines changed

tests/unit/async_utils/services/event_logger/test_event_logger.py

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,17 @@ def _make_stub(*args, **kwargs) -> tuple[StubEventLoggerService, list[FakeWriter
118118

119119
@pytest.mark.unit
120120
class TestIsErrorEvent:
121-
def test_error_event_types(self):
122-
for et in ErrorEventType:
123-
assert _is_error_event(_record(et)) is True
124-
125-
def test_session_events_are_not_errors(self):
126-
for et in SessionEventType:
127-
assert _is_error_event(_record(et)) is False
128-
129-
def test_sample_events_are_not_errors(self):
130-
for et in SampleEventType:
131-
assert _is_error_event(_record(et)) is False
121+
@pytest.mark.parametrize(
122+
"case_desc, event_type_class, expected",
123+
[
124+
("error events", ErrorEventType, True),
125+
("session events", SessionEventType, False),
126+
("sample events", SampleEventType, False),
127+
],
128+
)
129+
def test_is_error_event(self, case_desc, event_type_class, expected):
130+
for et in event_type_class:
131+
assert _is_error_event(_record(et)) is expected
132132

133133

134134
# ---------------------------------------------------------------------------
@@ -181,7 +181,7 @@ async def test_multiple_batches_accumulate(self):
181181
@pytest.mark.unit
182182
class TestShutdownBehavior:
183183
@pytest.mark.asyncio
184-
async def test_session_ended_triggers_flush_and_close(self):
184+
async def test_ended_triggers_flush_close(self):
185185
service, writers = _make_stub()
186186
await service.process([_record(SessionEventType.ENDED, ts=100)])
187187
for writer in writers:
@@ -216,7 +216,7 @@ async def test_events_after_ended_are_dropped(self):
216216
assert len(new_writer.written) == 0
217217

218218
@pytest.mark.asyncio
219-
async def test_non_error_events_after_ended_in_same_batch_dropped(self):
219+
async def test_non_errors_after_ended_same_batch(self):
220220
service, writers = _make_stub()
221221
await service.process(
222222
[
@@ -230,7 +230,7 @@ async def test_non_error_events_after_ended_in_same_batch_dropped(self):
230230
assert writer.written[0].event_type == SessionEventType.ENDED
231231

232232
@pytest.mark.asyncio
233-
async def test_error_events_after_ended_in_same_batch_still_written(self):
233+
async def test_errors_after_ended_same_batch_kept(self):
234234
service, writers = _make_stub()
235235
err_data = ErrorData(error_type="TestError", error_message="boom")
236236
await service.process(
@@ -244,7 +244,7 @@ async def test_error_events_after_ended_in_same_batch_still_written(self):
244244
assert writer.written[1].event_type == ErrorEventType.GENERIC
245245

246246
@pytest.mark.asyncio
247-
async def test_error_events_after_ended_in_later_batch_dropped(self):
247+
async def test_errors_after_ended_later_batch(self):
248248
"""Error events are only kept in the same batch as ENDED.
249249
250250
After the batch containing ENDED completes, writers are closed and
@@ -519,47 +519,42 @@ async def test_full_lifecycle(self, tmp_path):
519519
@pytest.mark.unit
520520
class TestEdgeCases:
521521
@pytest.mark.asyncio
522-
async def test_all_error_event_types_are_recognized(self):
523-
service, writers = _make_stub()
524-
error_records = [_record(et, ts=i) for i, et in enumerate(ErrorEventType)]
525-
await service.process(error_records)
526-
for writer in writers:
527-
assert len(writer.written) == len(list(ErrorEventType))
528-
529-
@pytest.mark.asyncio
530-
async def test_all_session_event_types_written(self):
531-
service, writers = _make_stub()
532-
session_records = [_record(et, ts=i) for i, et in enumerate(SessionEventType)]
533-
await service.process(session_records)
534-
for writer in writers:
535-
# All session events should be written
536-
# (ENDED is among them but everything in the batch up to and including ENDED is written)
537-
assert len(writer.written) == len(list(SessionEventType))
538-
539-
@pytest.mark.asyncio
540-
async def test_all_sample_event_types_written(self):
522+
@pytest.mark.parametrize(
523+
"case_desc, event_type_class, use_uuid",
524+
[
525+
("error events", ErrorEventType, False),
526+
("session events", SessionEventType, False),
527+
("sample events", SampleEventType, True),
528+
],
529+
)
530+
async def test_all_event_types_written(self, case_desc, event_type_class, use_uuid):
541531
service, writers = _make_stub()
542-
sample_records = [
543-
_record(et, uuid="s1", ts=i) for i, et in enumerate(SampleEventType)
532+
records = [
533+
_record(et, uuid="s1" if use_uuid else "", ts=i)
534+
for i, et in enumerate(event_type_class)
544535
]
545-
await service.process(sample_records)
546-
for writer in writers:
547-
assert len(writer.written) == len(list(SampleEventType))
548-
549-
@pytest.mark.asyncio
550-
async def test_record_with_no_data(self):
551-
service, writers = _make_stub()
552-
await service.process([_record(SampleEventType.ISSUED, uuid="s1")])
536+
await service.process(records)
553537
for writer in writers:
554-
assert writer.written[0].data is None
538+
assert len(writer.written) == len(list(event_type_class))
555539

556540
@pytest.mark.asyncio
557-
async def test_record_with_error_data(self):
541+
@pytest.mark.parametrize(
542+
"case_desc, event_type, data, expected_data",
543+
[
544+
("no data", SampleEventType.ISSUED, None, None),
545+
(
546+
"error data",
547+
ErrorEventType.CLIENT,
548+
ErrorData(error_type="SomeError", error_message="detail"),
549+
ErrorData(error_type="SomeError", error_message="detail"),
550+
),
551+
],
552+
)
553+
async def test_record_data(self, case_desc, event_type, data, expected_data):
558554
service, writers = _make_stub()
559-
err = ErrorData(error_type="SomeError", error_message="detail")
560-
await service.process([_record(ErrorEventType.CLIENT, data=err)])
555+
await service.process([_record(event_type, uuid="s1", data=data)])
561556
for writer in writers:
562-
assert writer.written[0].data == err
557+
assert writer.written[0].data == expected_data
563558

564559
@pytest.mark.asyncio
565560
async def test_large_batch(self):

tests/unit/async_utils/services/event_logger/test_sql_writer.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,32 +47,38 @@ def _record(event_type, uuid="", ts=0, data=None):
4747

4848
@pytest.mark.unit
4949
class TestRecordToRow:
50-
def test_sample_event_topic(self):
51-
row = _record_to_row(_record(SampleEventType.ISSUED, uuid="s1", ts=1000))
52-
assert row.event_type == "sample.issued"
53-
assert row.sample_uuid == "s1"
54-
assert row.timestamp_ns == 1000
55-
56-
def test_session_event_topic(self):
57-
row = _record_to_row(_record(SessionEventType.ENDED, ts=42))
58-
assert row.event_type == "session.ended"
59-
assert row.sample_uuid == ""
60-
assert row.timestamp_ns == 42
61-
62-
def test_error_event_topic(self):
63-
row = _record_to_row(_record(ErrorEventType.GENERIC, ts=99))
64-
assert row.event_type == "error.generic"
65-
66-
def test_data_is_json_encoded(self):
67-
err = ErrorData(error_type="TestError", error_message="boom")
68-
row = _record_to_row(_record(SampleEventType.COMPLETE, data=err))
69-
decoded = msgspec.json.decode(row.data)
70-
assert "TestError" in str(decoded)
71-
72-
def test_none_data_encodes_to_null(self):
73-
row = _record_to_row(_record(SampleEventType.ISSUED))
50+
@pytest.mark.parametrize(
51+
"case_desc, event_type, uuid, ts, expected_topic",
52+
[
53+
("sample event", SampleEventType.ISSUED, "s1", 1000, "sample.issued"),
54+
("session event", SessionEventType.ENDED, "", 42, "session.ended"),
55+
("error event", ErrorEventType.GENERIC, "", 99, "error.generic"),
56+
],
57+
)
58+
def test_topic_and_fields(self, case_desc, event_type, uuid, ts, expected_topic):
59+
row = _record_to_row(_record(event_type, uuid=uuid, ts=ts))
60+
assert row.event_type == expected_topic
61+
assert row.sample_uuid == uuid
62+
assert row.timestamp_ns == ts
63+
64+
@pytest.mark.parametrize(
65+
"case_desc, data, check_str",
66+
[
67+
(
68+
"error data",
69+
ErrorData(error_type="TestError", error_message="boom"),
70+
"TestError",
71+
),
72+
("none data", None, None),
73+
],
74+
)
75+
def test_data_encoding(self, case_desc, data, check_str):
76+
row = _record_to_row(_record(SampleEventType.COMPLETE, data=data))
7477
decoded = msgspec.json.decode(row.data)
75-
assert decoded is None
78+
if check_str is not None:
79+
assert check_str in str(decoded)
80+
else:
81+
assert decoded is None
7682

7783

7884
# ---------------------------------------------------------------------------

tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ async def test_all_timing_metrics_full_lifecycle(self):
270270
assert m["chunk_delta_ns"] == 1000
271271

272272
@pytest.mark.asyncio
273-
async def test_chunk_delta_not_emitted_without_last_recv(self):
273+
async def test_chunk_delta_needs_last_recv(self):
274274
"""RECV_NON_FIRST without prior RECV_FIRST: no chunk_delta emitted."""
275275
emitter = FakeEmitter()
276276
agg = StubAggregator(emitter)
@@ -285,7 +285,7 @@ async def test_chunk_delta_not_emitted_without_last_recv(self):
285285
assert row.last_recv_ns is None # No recv events yet
286286

287287
@pytest.mark.asyncio
288-
async def test_request_duration_not_emitted_without_client_send(self):
288+
async def test_req_duration_needs_client_send(self):
289289
"""CLIENT_RESP_DONE without CLIENT_SEND: no request_duration."""
290290
emitter = FakeEmitter()
291291
agg = StubAggregator(emitter)
@@ -326,7 +326,7 @@ async def test_issued_stores_prompt_text(self):
326326
assert row.prompt_text == "What is AI?"
327327

328328
@pytest.mark.asyncio
329-
async def test_issued_with_token_ids_emits_isl_directly(self):
329+
async def test_token_ids_emit_isl_directly(self):
330330
"""SGLang path: PromptData with token_ids emits ISL = len(token_ids)
331331
without tokenization."""
332332
emitter = FakeEmitter()

tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,53 +39,47 @@ def test_is_msgspec_struct(self):
3939
row = SampleRow("s1")
4040
assert isinstance(row, msgspec.Struct)
4141

42-
def test_ttft(self):
43-
row = SampleRow("s1")
44-
row.issued_ns = 1000
45-
row.recv_first_ns = 2500
46-
assert row.ttft_ns() == 1500
47-
48-
def test_ttft_returns_none_without_issued(self):
49-
row = SampleRow("s1")
50-
row.recv_first_ns = 2500
51-
assert row.ttft_ns() is None
52-
53-
def test_ttft_returns_none_without_recv_first(self):
54-
row = SampleRow("s1")
55-
row.issued_ns = 1000
56-
assert row.ttft_ns() is None
57-
58-
def test_sample_latency(self):
59-
row = SampleRow("s1")
60-
row.issued_ns = 1000
61-
row.complete_ns = 5000
62-
assert row.sample_latency_ns() == 4000
63-
64-
def test_sample_latency_returns_none_without_issued(self):
65-
row = SampleRow("s1")
66-
row.complete_ns = 5000
67-
assert row.sample_latency_ns() is None
68-
69-
def test_sample_latency_returns_none_without_complete(self):
70-
row = SampleRow("s1")
71-
row.issued_ns = 1000
72-
assert row.sample_latency_ns() is None
73-
74-
def test_request_duration(self):
75-
row = SampleRow("s1")
76-
row.client_send_ns = 100
77-
row.client_resp_done_ns = 600
78-
assert row.request_duration_ns() == 500
79-
80-
def test_request_duration_returns_none_without_send(self):
81-
row = SampleRow("s1")
82-
row.client_resp_done_ns = 600
83-
assert row.request_duration_ns() is None
84-
85-
def test_request_duration_returns_none_without_resp_done(self):
86-
row = SampleRow("s1")
87-
row.client_send_ns = 100
88-
assert row.request_duration_ns() is None
42+
@pytest.mark.parametrize(
43+
"case_desc, issued, recv_first, expected",
44+
[
45+
("both set", 1000, 2500, 1500),
46+
("no issued", None, 2500, None),
47+
("no recv_first", 1000, None, None),
48+
],
49+
)
50+
def test_ttft(self, case_desc, issued, recv_first, expected):
51+
row = SampleRow("s1")
52+
row.issued_ns = issued
53+
row.recv_first_ns = recv_first
54+
assert row.ttft_ns() == expected
55+
56+
@pytest.mark.parametrize(
57+
"case_desc, issued, complete, expected",
58+
[
59+
("both set", 1000, 5000, 4000),
60+
("no issued", None, 5000, None),
61+
("no complete", 1000, None, None),
62+
],
63+
)
64+
def test_sample_latency(self, case_desc, issued, complete, expected):
65+
row = SampleRow("s1")
66+
row.issued_ns = issued
67+
row.complete_ns = complete
68+
assert row.sample_latency_ns() == expected
69+
70+
@pytest.mark.parametrize(
71+
"case_desc, send, resp_done, expected",
72+
[
73+
("both set", 100, 600, 500),
74+
("no send", None, 600, None),
75+
("no resp_done", 100, None, None),
76+
],
77+
)
78+
def test_request_duration(self, case_desc, send, resp_done, expected):
79+
row = SampleRow("s1")
80+
row.client_send_ns = send
81+
row.client_resp_done_ns = resp_done
82+
assert row.request_duration_ns() == expected
8983

9084
def test_output_text_empty(self):
9185
row = SampleRow("s1")

0 commit comments

Comments
 (0)