@@ -105,26 +105,25 @@ def structure_float_or_none(obj: Any, cl: type) -> Optional[float]:
105105 converter .register_structure_hook (float , structure_float_or_none )
106106 converter .register_structure_hook (Optional [float ], structure_float_or_none )
107107
108+ # Helper function to filter valid fields
109+ def filter_valid_fields (obj , cls ):
110+ return {k : v for k , v in obj .items () if k in fields_dict (cls )}
111+
108112 # Register structure hooks for nested types
109113 converter .register_structure_hook (Role , lambda obj , _ : Role (obj ))
110- converter .register_structure_hook (Message , lambda obj , _ : Message (** obj ))
111- converter .register_structure_hook (LLMInputs , lambda obj , _ : LLMInputs (** obj ))
112- converter .register_structure_hook (EvaluationResult , lambda obj , _ : EvaluationResult (** obj ))
113- converter .register_structure_hook (TraceLogImage , lambda obj , _ : TraceLogImage (** obj ))
114- converter .register_structure_hook (TraceLogCommentSchema , lambda obj , _ : TraceLogCommentSchema (** obj ))
115- converter .register_structure_hook (TraceLogAnnotationSchema , lambda obj , _ : TraceLogAnnotationSchema (** obj ))
116-
117- def structure_model_params (obj , _ ):
118- valid_params = {k : v for k , v in obj .items () if k in fields_dict (ModelParams )}
119- return ModelParams (** valid_params )
120-
121- converter .register_structure_hook (ModelParams , structure_model_params )
114+ converter .register_structure_hook (Message , lambda obj , _ : Message (** filter_valid_fields (obj , Message )))
115+ converter .register_structure_hook (EvaluationResult , lambda obj , _ : EvaluationResult (** filter_valid_fields (obj , EvaluationResult )))
116+ converter .register_structure_hook (TraceLogImage , lambda obj , _ : TraceLogImage (** filter_valid_fields (obj , TraceLogImage )))
117+ converter .register_structure_hook (TraceLogCommentSchema , lambda obj , _ : TraceLogCommentSchema (** filter_valid_fields (obj , TraceLogCommentSchema )))
118+ converter .register_structure_hook (TraceLogAnnotationSchema , lambda obj , _ : TraceLogAnnotationSchema (** filter_valid_fields (obj , TraceLogAnnotationSchema )))
119+ converter .register_structure_hook (ModelParams , lambda obj , _ : ModelParams (** filter_valid_fields (obj , ModelParams )))
122120
123121 def structure_llm_inputs (obj , _ ):
124122 if obj is None :
125123 return None
124+ valid_fields = filter_valid_fields (obj , LLMInputs )
126125 kwargs = {}
127- for key , value in obj .items ():
126+ for key , value in valid_fields .items ():
128127 if key == "messages" :
129128 kwargs [key ] = [converter .structure (msg , Message ) for msg in value ]
130129 elif key == "model_params" :
@@ -136,8 +135,9 @@ def structure_llm_inputs(obj, _):
136135 converter .register_structure_hook (LLMInputs , structure_llm_inputs )
137136
138137 def structure_trace_log_tree (data , _ ):
138+ valid_fields = filter_valid_fields (data , TraceLogTree )
139139 kwargs = {}
140- for key , value in data .items ():
140+ for key , value in valid_fields .items ():
141141 if key == "children_logs" :
142142 kwargs ["children_logs" ] = [structure_trace_log_tree (child , TraceLogTree ) for child in value ]
143143 elif key == "configuration" :
@@ -150,7 +150,7 @@ def structure_trace_log_tree(data, _):
150150 kwargs ["comments" ] = [converter .structure (comment , TraceLogCommentSchema ) for comment in value ]
151151 elif key == "annotations" :
152152 kwargs ["annotations" ] = {int (k ): {sk : converter .structure (sv , TraceLogAnnotationSchema ) for sk , sv in v .items ()} for k , v in value .items ()}
153- elif key in fields_dict ( TraceLogTree ) :
153+ else :
154154 field_type = fields_dict (TraceLogTree )[key ].type
155155 kwargs [key ] = converter .structure (value , field_type )
156156 return TraceLogTree (** kwargs )
0 commit comments