44from collections .abc import Mapping
55from dataclasses import dataclass
66from io import StringIO
7- from typing import Any , Callable , Generic , Literal , Protocol
7+ from typing import Any , Callable , Generic , Literal , Protocol , cast
88
99from pydantic import BaseModel , TypeAdapter
1010from rich .console import Console
@@ -168,6 +168,7 @@ def print(
168168 self ,
169169 width : int | None = None ,
170170 baseline : EvaluationReport [InputsT , OutputT , MetadataT ] | None = None ,
171+ * ,
171172 include_input : bool = False ,
172173 include_metadata : bool = False ,
173174 include_expected_output : bool = False ,
@@ -183,6 +184,7 @@ def print(
183184 label_configs : dict [str , RenderValueConfig ] | None = None ,
184185 metric_configs : dict [str , RenderNumberConfig ] | None = None ,
185186 duration_config : RenderNumberConfig | None = None ,
187+ include_reasons : bool = False ,
186188 ): # pragma: no cover
187189 """Print this report to the console, optionally comparing it to a baseline report.
188190
@@ -205,12 +207,14 @@ def print(
205207 label_configs = label_configs ,
206208 metric_configs = metric_configs ,
207209 duration_config = duration_config ,
210+ include_reasons = include_reasons ,
208211 )
209212 Console (width = width ).print (table )
210213
211214 def console_table (
212215 self ,
213216 baseline : EvaluationReport [InputsT , OutputT , MetadataT ] | None = None ,
217+ * ,
214218 include_input : bool = False ,
215219 include_metadata : bool = False ,
216220 include_expected_output : bool = False ,
@@ -226,6 +230,7 @@ def console_table(
226230 label_configs : dict [str , RenderValueConfig ] | None = None ,
227231 metric_configs : dict [str , RenderNumberConfig ] | None = None ,
228232 duration_config : RenderNumberConfig | None = None ,
233+ include_reasons : bool = False ,
229234 ) -> Table :
230235 """Return a table containing the data from this report, or the diff between this report and a baseline report.
231236
@@ -247,6 +252,7 @@ def console_table(
247252 label_configs = label_configs or {},
248253 metric_configs = metric_configs or {},
249254 duration_config = duration_config or _DEFAULT_DURATION_CONFIG ,
255+ include_reasons = include_reasons ,
250256 )
251257 if baseline is None :
252258 return renderer .build_table (self )
@@ -529,15 +535,16 @@ class ReportCaseRenderer:
529535 include_labels : bool
530536 include_metrics : bool
531537 include_assertions : bool
538+ include_reasons : bool
532539 include_durations : bool
533540 include_total_duration : bool
534541
535542 input_renderer : _ValueRenderer
536543 metadata_renderer : _ValueRenderer
537544 output_renderer : _ValueRenderer
538- score_renderers : dict [str , _NumberRenderer ]
539- label_renderers : dict [str , _ValueRenderer ]
540- metric_renderers : dict [str , _NumberRenderer ]
545+ score_renderers : Mapping [str , _NumberRenderer ]
546+ label_renderers : Mapping [str , _ValueRenderer ]
547+ metric_renderers : Mapping [str , _NumberRenderer ]
541548 duration_renderer : _NumberRenderer
542549
543550 def build_base_table (self , title : str ) -> Table :
@@ -581,10 +588,10 @@ def build_row(self, case: ReportCase) -> list[str]:
581588 row .append (self .output_renderer .render_value (None , case .output ) or EMPTY_CELL_STR )
582589
583590 if self .include_scores :
584- row .append (self ._render_dict ({k : v . value for k , v in case .scores .items ()}, self .score_renderers ))
591+ row .append (self ._render_dict ({k : v for k , v in case .scores .items ()}, self .score_renderers ))
585592
586593 if self .include_labels :
587- row .append (self ._render_dict ({k : v . value for k , v in case .labels .items ()}, self .label_renderers ))
594+ row .append (self ._render_dict ({k : v for k , v in case .labels .items ()}, self .label_renderers ))
588595
589596 if self .include_metrics :
590597 row .append (self ._render_dict (case .metrics , self .metric_renderers ))
@@ -783,26 +790,36 @@ def _render_dicts_diff(
783790 diff_lines .append (rendered )
784791 return '\n ' .join (diff_lines ) if diff_lines else EMPTY_CELL_STR
785792
786- @staticmethod
787793 def _render_dict (
788- case_dict : dict [str , T ],
794+ self ,
795+ case_dict : Mapping [str , EvaluationResult [T ] | T ],
789796 renderers : Mapping [str , _AbstractRenderer [T ]],
790797 * ,
791798 include_names : bool = True ,
792799 ) -> str :
793800 diff_lines : list [str ] = []
794801 for key , val in case_dict .items ():
795- rendered = renderers [key ].render_value (key if include_names else None , val )
802+ value = cast (EvaluationResult [T ], val ).value if isinstance (val , EvaluationResult ) else val
803+ rendered = renderers [key ].render_value (key if include_names else None , value )
804+ if self .include_reasons and isinstance (val , EvaluationResult ) and (reason := val .reason ):
805+ rendered += f'\n Reason: { reason } \n '
796806 diff_lines .append (rendered )
797807 return '\n ' .join (diff_lines ) if diff_lines else EMPTY_CELL_STR
798808
799- @staticmethod
800809 def _render_assertions (
810+ self ,
801811 assertions : list [EvaluationResult [bool ]],
802812 ) -> str :
803813 if not assertions :
804814 return EMPTY_CELL_STR
805- return '' .join (['[green]✔[/]' if a .value else '[red]✗[/]' for a in assertions ])
815+ lines : list [str ] = []
816+ for a in assertions :
817+ line = '[green]✔[/]' if a .value else '[red]✗[/]'
818+ if self .include_reasons :
819+ line = f'{ a .name } : { line } \n '
820+ line = f'{ line } Reason: { a .reason } \n \n ' if a .reason else line
821+ lines .append (line )
822+ return '' .join (lines )
806823
807824 @staticmethod
808825 def _render_aggregate_assertions (
@@ -863,6 +880,10 @@ class EvaluationRenderer:
863880 metric_configs : dict [str , RenderNumberConfig ]
864881 duration_config : RenderNumberConfig
865882
883+ # TODO: Make this class kw-only so we can reorder the kwargs
884+ # Data to include
885+ include_reasons : bool # only applies to reports, not to diffs
886+
866887 def include_scores (self , report : EvaluationReport , baseline : EvaluationReport | None = None ):
867888 return any (case .scores for case in self ._all_cases (report , baseline ))
868889
@@ -909,6 +930,7 @@ def _get_case_renderer(
909930 include_labels = self .include_labels (report , baseline ),
910931 include_metrics = self .include_metrics (report , baseline ),
911932 include_assertions = self .include_assertions (report , baseline ),
933+ include_reasons = self .include_reasons ,
912934 include_durations = self .include_durations ,
913935 include_total_duration = self .include_total_duration ,
914936 input_renderer = input_renderer ,
0 commit comments