Skip to content

Commit e5dd1fa

Browse files
committed
fix: correct creation of trace log from api response
1 parent 971b118 commit e5dd1fa

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

parea/helpers.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from cattrs import GenConverter
1212

1313
from parea.constants import ADJECTIVES, NOUNS, TURN_OFF_PAREA_LOGGING
14-
from parea.schemas.models import Completion, PaginatedTraceLogsResponse, TraceLog, TraceLogTree, UpdateLog
14+
from parea.schemas import EvaluationResult, LLMInputs, Message, ModelParams, Role
15+
from parea.schemas.models import Completion, PaginatedTraceLogsResponse, TraceLog, TraceLogAnnotationSchema, TraceLogCommentSchema, TraceLogImage, TraceLogTree, UpdateLog
1516
from parea.utils.universal_encoder import json_dumps
1617

1718

@@ -104,13 +105,54 @@ def structure_float_or_none(obj: Any, cl: type) -> Optional[float]:
104105
converter.register_structure_hook(float, structure_float_or_none)
105106
converter.register_structure_hook(Optional[float], structure_float_or_none)
106107

108+
# Register structure hooks for nested types
109+
converter.register_structure_hook(Role, lambda obj, _: Role(obj))
110+
converter.register_structure_hook(Message, lambda obj, _: Message(**obj))
111+
converter.register_structure_hook(LLMInputs, lambda obj, _: LLMInputs(**obj))
112+
converter.register_structure_hook(EvaluationResult, lambda obj, _: EvaluationResult(**obj))
113+
converter.register_structure_hook(TraceLogImage, lambda obj, _: TraceLogImage(**obj))
114+
converter.register_structure_hook(TraceLogCommentSchema, lambda obj, _: TraceLogCommentSchema(**obj))
115+
converter.register_structure_hook(TraceLogAnnotationSchema, lambda obj, _: TraceLogAnnotationSchema(**obj))
116+
117+
def structure_model_params(obj, _):
118+
valid_params = {k: v for k, v in obj.items() if k in fields_dict(ModelParams)}
119+
return ModelParams(**valid_params)
120+
121+
converter.register_structure_hook(ModelParams, structure_model_params)
122+
123+
def structure_llm_inputs(obj, _):
124+
if obj is None:
125+
return None
126+
kwargs = {}
127+
for key, value in obj.items():
128+
if key == "messages":
129+
kwargs[key] = [converter.structure(msg, Message) for msg in value]
130+
elif key == "model_params":
131+
kwargs[key] = converter.structure(value, ModelParams)
132+
else:
133+
kwargs[key] = value
134+
return LLMInputs(**kwargs)
135+
136+
converter.register_structure_hook(LLMInputs, structure_llm_inputs)
137+
107138
def structure_trace_log_tree(data, _):
108139
kwargs = {}
109140
for key, value in data.items():
110141
if key == "children_logs":
111142
kwargs["children_logs"] = [structure_trace_log_tree(child, TraceLogTree) for child in value]
143+
elif key == "configuration":
144+
kwargs["configuration"] = converter.structure(value, LLMInputs)
145+
elif key == "scores":
146+
kwargs["scores"] = [converter.structure(score, EvaluationResult) for score in value]
147+
elif key == "images":
148+
kwargs["images"] = [converter.structure(image, TraceLogImage) for image in value]
149+
elif key == "comments":
150+
kwargs["comments"] = [converter.structure(comment, TraceLogCommentSchema) for comment in value]
151+
elif key == "annotations":
152+
kwargs["annotations"] = {int(k): {sk: converter.structure(sv, TraceLogAnnotationSchema) for sk, sv in v.items()} for k, v in value.items()}
112153
elif key in fields_dict(TraceLogTree):
113-
kwargs[key] = value
154+
field_type = fields_dict(TraceLogTree)[key].type
155+
kwargs[key] = converter.structure(value, field_type)
114156
return TraceLogTree(**kwargs)
115157

116158
converter.register_structure_hook(TraceLogTree, structure_trace_log_tree)

parea/schemas/log.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional, Union
22

3+
import json
34
import math
45
from enum import Enum
56

@@ -68,6 +69,36 @@ class Log:
6869
total_tokens: Optional[int] = 0
6970
cost: Optional[float] = 0.0
7071

72+
def convert_to_jsonl_row_for_finetuning(self) -> dict:
73+
"""Converts the trace log to a row in the finetuning jsonl format."""
74+
jsonl_row = {"messages": [m.to_dict() for m in self.configuration.messages]}
75+
output = self.output
76+
try:
77+
tool_calls = json.loads(output)
78+
tools = self.configuration.functions
79+
# if 'arguments' is in the output, it was actually a function call
80+
if "arguments" in tool_calls:
81+
function_call = tool_calls[0] if isinstance(tool_calls, List) else tool_calls
82+
function_call["arguments"] = json.dumps(function_call["arguments"])
83+
assistant_response = {
84+
"role": "assistant",
85+
"function_call": function_call,
86+
}
87+
jsonl_row["functions"] = tools
88+
else:
89+
tool_calls = tool_calls if isinstance(tool_calls, List) else [tool_calls]
90+
for tool_call in tool_calls:
91+
tool_call["arguments"] = json.dumps(tool_call["arguments"])
92+
assistant_response = {
93+
"role": "assistant",
94+
"tool_calls": tool_calls,
95+
}
96+
jsonl_row["tools"] = [{"type": "function", "function": tool} for tool in tools]
97+
except json.JSONDecodeError:
98+
assistant_response = {"role": "assistant", "content": output}
99+
jsonl_row["messages"].append(assistant_response)
100+
return jsonl_row
101+
71102

72103
@define
73104
class EvaluationResult:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
66
[tool.poetry]
77
name = "parea-ai"
88
packages = [{ include = "parea" }]
9-
version = "0.2.203"
9+
version = "0.2.204"
1010
description = "Parea python sdk"
1111
readme = "README.md"
1212
authors = ["joel-parea-ai <[email protected]>"]

0 commit comments

Comments
 (0)