@@ -4,7 +4,7 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
 from io import StringIO
-from typing import Any, Callable, Generic, Literal, Protocol
+from typing import Any, Callable, Generic, Literal, Protocol, cast

 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -168,6 +168,7 @@ def print(
         self,
         width: int | None = None,
         baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
+        *,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -183,6 +184,7 @@ def print(
         label_configs: dict[str, RenderValueConfig] | None = None,
         metric_configs: dict[str, RenderNumberConfig] | None = None,
         duration_config: RenderNumberConfig | None = None,
+        include_reasons: bool = False,
     ):  # pragma: no cover
         """Print this report to the console, optionally comparing it to a baseline report.

@@ -205,12 +207,14 @@ def print(
             label_configs=label_configs,
             metric_configs=metric_configs,
             duration_config=duration_config,
+            include_reasons=include_reasons,
         )
         Console(width=width).print(table)

     def console_table(
         self,
         baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
+        *,
         include_input: bool = False,
         include_metadata: bool = False,
         include_expected_output: bool = False,
@@ -226,6 +230,7 @@ def console_table(
         label_configs: dict[str, RenderValueConfig] | None = None,
         metric_configs: dict[str, RenderNumberConfig] | None = None,
         duration_config: RenderNumberConfig | None = None,
+        include_reasons: bool = False,
     ) -> Table:
         """Return a table containing the data from this report, or the diff between this report and a baseline report.

@@ -247,6 +252,7 @@ def console_table(
             label_configs=label_configs or {},
             metric_configs=metric_configs or {},
             duration_config=duration_config or _DEFAULT_DURATION_CONFIG,
+            include_reasons=include_reasons,
         )
         if baseline is None:
             return renderer.build_table(self)
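
Taken together, the hunks above add a keyword-only include_reasons flag (default False) to both print and console_table. A minimal usage sketch, assuming report is an existing EvaluationReport instance (the variable name is illustrative and not part of this diff):

    # Hypothetical caller code; only the include_reasons flag comes from this diff.
    report.print(include_reasons=True)                  # print the table with per-evaluator reasons
    table = report.console_table(include_reasons=True)  # or build the rich Table directly

Because of the new `*,` marker, this flag (like the other display options after it) must be passed by keyword.
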
@@ -529,15 +535,16 @@ class ReportCaseRenderer:
     include_labels: bool
     include_metrics: bool
     include_assertions: bool
+    include_reasons: bool
     include_durations: bool
     include_total_duration: bool

     input_renderer: _ValueRenderer
     metadata_renderer: _ValueRenderer
     output_renderer: _ValueRenderer
-    score_renderers: dict[str, _NumberRenderer]
-    label_renderers: dict[str, _ValueRenderer]
-    metric_renderers: dict[str, _NumberRenderer]
+    score_renderers: Mapping[str, _NumberRenderer]
+    label_renderers: Mapping[str, _ValueRenderer]
+    metric_renderers: Mapping[str, _NumberRenderer]
     duration_renderer: _NumberRenderer

     def build_base_table(self, title: str) -> Table:
@@ -581,10 +588,10 @@ def build_row(self, case: ReportCase) -> list[str]:
             row.append(self.output_renderer.render_value(None, case.output) or EMPTY_CELL_STR)

         if self.include_scores:
-            row.append(self._render_dict({k: v.value for k, v in case.scores.items()}, self.score_renderers))
+            row.append(self._render_dict({k: v for k, v in case.scores.items()}, self.score_renderers))

         if self.include_labels:
-            row.append(self._render_dict({k: v.value for k, v in case.labels.items()}, self.label_renderers))
+            row.append(self._render_dict({k: v for k, v in case.labels.items()}, self.label_renderers))

         if self.include_metrics:
             row.append(self._render_dict(case.metrics, self.metric_renderers))
@@ -783,26 +790,36 @@ def _render_dicts_diff(
             diff_lines.append(rendered)
         return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR

-    @staticmethod
     def _render_dict(
-        case_dict: dict[str, T],
+        self,
+        case_dict: Mapping[str, EvaluationResult[T] | T],
         renderers: Mapping[str, _AbstractRenderer[T]],
         *,
         include_names: bool = True,
     ) -> str:
         diff_lines: list[str] = []
         for key, val in case_dict.items():
-            rendered = renderers[key].render_value(key if include_names else None, val)
+            value = cast(EvaluationResult[T], val).value if isinstance(val, EvaluationResult) else val
+            rendered = renderers[key].render_value(key if include_names else None, value)
+            if self.include_reasons and isinstance(val, EvaluationResult) and (reason := val.reason):
+                rendered += f'\nReason: {reason}\n'
             diff_lines.append(rendered)
         return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR

-    @staticmethod
     def _render_assertions(
+        self,
         assertions: list[EvaluationResult[bool]],
     ) -> str:
         if not assertions:
             return EMPTY_CELL_STR
-        return ''.join(['[green]✔[/]' if a.value else '[red]✗[/]' for a in assertions])
+        lines: list[str] = []
+        for a in assertions:
+            line = '[green]✔[/]' if a.value else '[red]✗[/]'
+            if self.include_reasons:
+                line = f'{a.name}: {line}\n'
+                line = f'{line}Reason: {a.reason}\n\n' if a.reason else line
+            lines.append(line)
+        return ''.join(lines)

     @staticmethod
     def _render_aggregate_assertions(
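
For orientation only (this example is not part of the diff; the assertion name and reason text are made up): following the string-building logic in the rewritten _render_assertions, a single passing assertion would render roughly as

    # include_reasons=False -> '[green]✔[/]'
    # include_reasons=True  -> 'correct: [green]✔[/]\nReason: matched expected output\n\n'

and _render_dict likewise appends a '\nReason: ...\n' line to a score or label cell when the stored EvaluationResult carries a reason.
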
@@ -863,6 +880,10 @@ class EvaluationRenderer:
     metric_configs: dict[str, RenderNumberConfig]
     duration_config: RenderNumberConfig

+    # TODO: Make this class kw-only so we can reorder the kwargs
+    # Data to include
+    include_reasons: bool  # only applies to reports, not to diffs
+
     def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
         return any(case.scores for case in self._all_cases(report, baseline))

@@ -909,6 +930,7 @@ def _get_case_renderer(
             include_labels=self.include_labels(report, baseline),
             include_metrics=self.include_metrics(report, baseline),
             include_assertions=self.include_assertions(report, baseline),
+            include_reasons=self.include_reasons,
             include_durations=self.include_durations,
             include_total_duration=self.include_total_duration,
             input_renderer=input_renderer,