Skip to content

Commit 4dd91ee

Browse files
arekay-nvattafosu
authored andcommitted
[fix] Handle case with string response (#155)
* Handle case with string response Handles the case where the response is a single string, not a list - needed to handle AMD submission which wasn't calculating TPOT without the fix. --------- Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
1 parent b04d7d3 commit 4dd91ee

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

src/inference_endpoint/metrics/reporter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,12 @@ def derive_TPOT(
10751075
output_sequence, reasoning_sequence = output_sequence_from_data(
10761076
data_bytes, join_chunks=False
10771077
)
1078+
if isinstance(output_sequence, str):
1079+
output_sequence = [output_sequence]
10781080
if not isinstance(output_sequence, list):
1081+
logging.warning(
1082+
f"Output sequence for sample {sample_uuid} is not a list but {type(output_sequence)}: {output_sequence}"
1083+
)
10791084
continue
10801085

10811086
all_chunks = output_sequence

tests/unit/metrics/test_reporter.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,88 @@ def test_derive_tpot(events_db, sample_uuids, fake_outputs, tokenizer):
7575
assert all(tpot == expected_tpot2 for tpot in tpot2)
7676

7777

78+
def test_derive_tpot_with_string_output(tmp_path, sample_uuids, tokenizer):
79+
"""Test that derive_TPOT handles a plain string output gracefully.
80+
81+
A single-string output has only one chunk, so TPOT cannot be computed.
82+
The reporter should not raise an exception and should return None.
83+
"""
84+
test_db = str(tmp_path / "test_string_output.db")
85+
uuid1 = sample_uuids(1)
86+
87+
with sqlite3_cursor(test_db) as (cursor, conn):
88+
cursor.execute(
89+
"CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
90+
)
91+
cursor.executemany(
92+
"INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
93+
[
94+
("", SessionEvent.TEST_STARTED.value, 5000, b""),
95+
(uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""),
96+
(uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""),
97+
(
98+
uuid1,
99+
SampleEvent.COMPLETE.value,
100+
10211,
101+
orjson.dumps({"output": "the final answer"}),
102+
),
103+
("", SessionEvent.TEST_ENDED.value, 10300, b""),
104+
],
105+
)
106+
conn.commit()
107+
108+
with MetricsReporter(test_db) as reporter:
109+
tpot_rows = reporter.derive_TPOT(tokenizer)
110+
111+
# A single-string output produces only 1 chunk — TPOT requires at least 2
112+
assert tpot_rows is None
113+
114+
115+
def test_derive_tpot_string_output_with_list_reasoning(
116+
tmp_path, sample_uuids, tokenizer
117+
):
118+
"""Test that derive_TPOT computes TPOT when string output is paired with a list reasoning sequence.
119+
120+
The fix wraps string outputs into a single-element list so they can be combined with
121+
reasoning chunks. Without the fix, the string output causes the sample to be silently
122+
skipped before reasoning is considered, so TPOT returns None even though there are
123+
enough chunks (output + reasoning) to compute it.
124+
"""
125+
test_db = str(tmp_path / "test_string_output_with_reasoning.db")
126+
uuid1 = sample_uuids(1)
127+
128+
with sqlite3_cursor(test_db) as (cursor, conn):
129+
cursor.execute(
130+
"CREATE TABLE IF NOT EXISTS events (sample_uuid VARCHAR(32), event_type VARCHAR(32), timestamp_ns INTEGER, data BLOB)"
131+
)
132+
cursor.executemany(
133+
"INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)",
134+
[
135+
("", SessionEvent.TEST_STARTED.value, 5000, b""),
136+
(uuid1, SessionEvent.LOADGEN_ISSUE_CALLED.value, 10000, b""),
137+
(uuid1, SampleEvent.FIRST_CHUNK.value, 10010, b""),
138+
(
139+
uuid1,
140+
SampleEvent.COMPLETE.value,
141+
10211,
142+
orjson.dumps(
143+
{"output": "the answer", "reasoning": ["thought step"]}
144+
),
145+
),
146+
("", SessionEvent.TEST_ENDED.value, 10300, b""),
147+
],
148+
)
149+
conn.commit()
150+
151+
with MetricsReporter(test_db) as reporter:
152+
tpot_rows = reporter.derive_TPOT(tokenizer)
153+
154+
# String output ("the answer") + list reasoning (["thought step"]) = 2 chunks total,
155+
# which is enough for TPOT computation.
156+
assert tpot_rows is not None
157+
assert len(tpot_rows) == 1
158+
159+
78160
def test_derive_sample_latency(events_db, sample_uuids):
79161
uuid1 = sample_uuids(1)
80162
uuid2 = sample_uuids(2)

0 commit comments

Comments
 (0)