Skip to content

Commit faafea5

Browse files
authored
fix(http): ensure consistent string output from response transforms (#34)
- Convert all transform outputs to strings consistently - Add test coverage for integer response values - Add test coverage for list response values - Fix spacing in message logging format The changes ensure that response transforms always return string values, fixing potential type mismatches when handling numeric or list responses.
1 parent 2dc5778 commit faafea5

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

rigging/generator/http.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,12 @@ def parse_response_body(self, data: str) -> str:
190190
matches = [match.value for match in jsonpath_expr.find(result)]
191191
if len(matches) == 0:
192192
raise Exception(f"No matches found for JSONPath: {transform.pattern} from {result}")
193-
result = json.dumps(matches) if len(matches) > 1 else matches[0]
193+
result = json.dumps(matches) if len(matches) > 1 else str(matches[0])
194194

195195
elif transform.type == "regex":
196196
matches = re.findall(_to_str(transform.pattern), result)
197197
matches = [str(match) for match in matches]
198-
result = json.dumps(matches) if len(matches) > 1 else matches[0]
198+
result = json.dumps(matches) if len(matches) > 1 else str(matches[0])
199199

200200
return result
201201

@@ -338,7 +338,7 @@ async def generate_messages(
338338
generated = await asyncio.gather(*coros)
339339

340340
for i, (_messages, response) in enumerate(zip(messages, generated)):
341-
trace_messages(_messages, f"Messages {i+1}/{len(messages)}")
342-
trace_messages([response], f"Response {i+1}/{len(messages)}")
341+
trace_messages(_messages, f"Messages {i + 1}/{len(messages)}")
342+
trace_messages([response], f"Response {i + 1}/{len(messages)}")
343343

344344
return generated

tests/test_http_spec.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,37 @@ def test_custom_header_templates() -> None:
212212
assert headers["Authorization"] == "Bearer test-key"
213213
assert headers["X-Request-ID"] == "default-id"
214214
assert headers["X-Model-Version"] == "test-model-v1"
215+
216+
217+
def test_parse_response_body_int() -> None:
218+
spec = HTTPSpec(
219+
request={
220+
"url": "https://api.example.com/v1/chat",
221+
"method": "POST",
222+
"headers": {"Authorization": "Bearer {{api_key}}"},
223+
"transforms": [{"type": "json", "pattern": {"model": "$model", "messages": "$messages"}}],
224+
},
225+
response={
226+
"valid_status_codes": [200, 201],
227+
"transforms": [{"type": "jsonpath", "pattern": "$.int_value"}],
228+
},
229+
)
230+
result = spec.parse_response_body('{"int_value": 42}')
231+
assert result == "42"
232+
233+
234+
def test_parse_response_body_list() -> None:
235+
spec = HTTPSpec(
236+
request={
237+
"url": "https://api.example.com/v1/chat",
238+
"method": "POST",
239+
"headers": {"Authorization": "Bearer {{api_key}}"},
240+
"transforms": [{"type": "json", "pattern": {"model": "$model", "messages": "$messages"}}],
241+
},
242+
response={
243+
"valid_status_codes": [200, 201],
244+
"transforms": [{"type": "jsonpath", "pattern": "foo[*].baz"}],
245+
},
246+
)
247+
result = spec.parse_response_body('{"foo": [{"baz": 1}, {"baz": 2}]}')
248+
assert result == "[1, 2]"

0 commit comments

Comments
 (0)