|
11 | 11 | from cattrs import GenConverter |
12 | 12 |
|
13 | 13 | from parea.constants import ADJECTIVES, NOUNS, TURN_OFF_PAREA_LOGGING |
14 | | -from parea.schemas.models import Completion, PaginatedTraceLogsResponse, TraceLog, TraceLogTree, UpdateLog |
| 14 | +from parea.schemas import EvaluationResult, LLMInputs, Message, ModelParams, Role |
| 15 | +from parea.schemas.models import Completion, PaginatedTraceLogsResponse, TraceLog, TraceLogAnnotationSchema, TraceLogCommentSchema, TraceLogImage, TraceLogTree, UpdateLog |
15 | 16 | from parea.utils.universal_encoder import json_dumps |
16 | 17 |
|
17 | 18 |
|
@@ -104,13 +105,54 @@ def structure_float_or_none(obj: Any, cl: type) -> Optional[float]: |
104 | 105 | converter.register_structure_hook(float, structure_float_or_none) |
105 | 106 | converter.register_structure_hook(Optional[float], structure_float_or_none) |
106 | 107 |
|
| 108 | + # Register structure hooks for nested types |
| 109 | + converter.register_structure_hook(Role, lambda obj, _: Role(obj)) |
| 110 | + converter.register_structure_hook(Message, lambda obj, _: Message(**obj)) |
| 111 | + converter.register_structure_hook(LLMInputs, lambda obj, _: LLMInputs(**obj)) |
| 112 | + converter.register_structure_hook(EvaluationResult, lambda obj, _: EvaluationResult(**obj)) |
| 113 | + converter.register_structure_hook(TraceLogImage, lambda obj, _: TraceLogImage(**obj)) |
| 114 | + converter.register_structure_hook(TraceLogCommentSchema, lambda obj, _: TraceLogCommentSchema(**obj)) |
| 115 | + converter.register_structure_hook(TraceLogAnnotationSchema, lambda obj, _: TraceLogAnnotationSchema(**obj)) |
| 116 | + |
| 117 | + def structure_model_params(obj, _): |
| 118 | + valid_params = {k: v for k, v in obj.items() if k in fields_dict(ModelParams)} |
| 119 | + return ModelParams(**valid_params) |
| 120 | + |
| 121 | + converter.register_structure_hook(ModelParams, structure_model_params) |
| 122 | + |
| 123 | + def structure_llm_inputs(obj, _): |
| 124 | + if obj is None: |
| 125 | + return None |
| 126 | + kwargs = {} |
| 127 | + for key, value in obj.items(): |
| 128 | + if key == "messages": |
| 129 | + kwargs[key] = [converter.structure(msg, Message) for msg in value] |
| 130 | + elif key == "model_params": |
| 131 | + kwargs[key] = converter.structure(value, ModelParams) |
| 132 | + else: |
| 133 | + kwargs[key] = value |
| 134 | + return LLMInputs(**kwargs) |
| 135 | + |
| 136 | + converter.register_structure_hook(LLMInputs, structure_llm_inputs) |
| 137 | + |
107 | 138 | def structure_trace_log_tree(data, _): |
108 | 139 | kwargs = {} |
109 | 140 | for key, value in data.items(): |
110 | 141 | if key == "children_logs": |
111 | 142 | kwargs["children_logs"] = [structure_trace_log_tree(child, TraceLogTree) for child in value] |
| 143 | + elif key == "configuration": |
| 144 | + kwargs["configuration"] = converter.structure(value, LLMInputs) |
| 145 | + elif key == "scores": |
| 146 | + kwargs["scores"] = [converter.structure(score, EvaluationResult) for score in value] |
| 147 | + elif key == "images": |
| 148 | + kwargs["images"] = [converter.structure(image, TraceLogImage) for image in value] |
| 149 | + elif key == "comments": |
| 150 | + kwargs["comments"] = [converter.structure(comment, TraceLogCommentSchema) for comment in value] |
| 151 | + elif key == "annotations": |
| 152 | + kwargs["annotations"] = {int(k): {sk: converter.structure(sv, TraceLogAnnotationSchema) for sk, sv in v.items()} for k, v in value.items()} |
112 | 153 | elif key in fields_dict(TraceLogTree): |
113 | | - kwargs[key] = value |
| 154 | + field_type = fields_dict(TraceLogTree)[key].type |
| 155 | + kwargs[key] = converter.structure(value, field_type) |
114 | 156 | return TraceLogTree(**kwargs) |
115 | 157 |
|
116 | 158 | converter.register_structure_hook(TraceLogTree, structure_trace_log_tree) |
|
0 commit comments