@@ -158,12 +158,15 @@ def print(
158
158
width : int | None = None ,
159
159
baseline : EvaluationReport | None = None ,
160
160
include_input : bool = False ,
161
+ include_metadata : bool = False ,
162
+ include_expected_output : bool = False ,
161
163
include_output : bool = False ,
162
164
include_durations : bool = True ,
163
165
include_total_duration : bool = False ,
164
166
include_removed_cases : bool = False ,
165
167
include_averages : bool = True ,
166
168
input_config : RenderValueConfig | None = None ,
169
+ metadata_config : RenderValueConfig | None = None ,
167
170
output_config : RenderValueConfig | None = None ,
168
171
score_configs : dict [str , RenderNumberConfig ] | None = None ,
169
172
label_configs : dict [str , RenderValueConfig ] | None = None ,
@@ -177,12 +180,15 @@ def print(
177
180
table = self .console_table (
178
181
baseline = baseline ,
179
182
include_input = include_input ,
183
+ include_metadata = include_metadata ,
184
+ include_expected_output = include_expected_output ,
180
185
include_output = include_output ,
181
186
include_durations = include_durations ,
182
187
include_total_duration = include_total_duration ,
183
188
include_removed_cases = include_removed_cases ,
184
189
include_averages = include_averages ,
185
190
input_config = input_config ,
191
+ metadata_config = metadata_config ,
186
192
output_config = output_config ,
187
193
score_configs = score_configs ,
188
194
label_configs = label_configs ,
@@ -195,12 +201,15 @@ def console_table(
195
201
self ,
196
202
baseline : EvaluationReport | None = None ,
197
203
include_input : bool = False ,
204
+ include_metadata : bool = False ,
205
+ include_expected_output : bool = False ,
198
206
include_output : bool = False ,
199
207
include_durations : bool = True ,
200
208
include_total_duration : bool = False ,
201
209
include_removed_cases : bool = False ,
202
210
include_averages : bool = True ,
203
211
input_config : RenderValueConfig | None = None ,
212
+ metadata_config : RenderValueConfig | None = None ,
204
213
output_config : RenderValueConfig | None = None ,
205
214
score_configs : dict [str , RenderNumberConfig ] | None = None ,
206
215
label_configs : dict [str , RenderValueConfig ] | None = None ,
@@ -213,12 +222,15 @@ def console_table(
213
222
"""
214
223
renderer = EvaluationRenderer (
215
224
include_input = include_input ,
225
+ include_metadata = include_metadata ,
226
+ include_expected_output = include_expected_output ,
216
227
include_output = include_output ,
217
228
include_durations = include_durations ,
218
229
include_total_duration = include_total_duration ,
219
230
include_removed_cases = include_removed_cases ,
220
231
include_averages = include_averages ,
221
232
input_config = {** _DEFAULT_VALUE_CONFIG , ** (input_config or {})},
233
+ metadata_config = {** _DEFAULT_VALUE_CONFIG , ** (metadata_config or {})},
222
234
output_config = output_config or _DEFAULT_VALUE_CONFIG ,
223
235
score_configs = score_configs or {},
224
236
label_configs = label_configs or {},
@@ -496,6 +508,8 @@ def render_diff(self, name: str | None, old: T_contra | None, new: T_contra | No
496
508
@dataclass
497
509
class ReportCaseRenderer :
498
510
include_input : bool
511
+ include_metadata : bool
512
+ include_expected_output : bool
499
513
include_output : bool
500
514
include_scores : bool
501
515
include_labels : bool
@@ -505,6 +519,7 @@ class ReportCaseRenderer:
505
519
include_total_duration : bool
506
520
507
521
input_renderer : _ValueRenderer
522
+ metadata_renderer : _ValueRenderer
508
523
output_renderer : _ValueRenderer
509
524
score_renderers : dict [str , _NumberRenderer ]
510
525
label_renderers : dict [str , _ValueRenderer ]
@@ -517,6 +532,10 @@ def build_base_table(self, title: str) -> Table:
517
532
table .add_column ('Case ID' , style = 'bold' )
518
533
if self .include_input :
519
534
table .add_column ('Inputs' , overflow = 'fold' )
535
+ if self .include_metadata :
536
+ table .add_column ('Metadata' , overflow = 'fold' )
537
+ if self .include_expected_output :
538
+ table .add_column ('Expected Output' , overflow = 'fold' )
520
539
if self .include_output :
521
540
table .add_column ('Outputs' , overflow = 'fold' )
522
541
if self .include_scores :
@@ -538,6 +557,12 @@ def build_row(self, case: ReportCase) -> list[str]:
538
557
if self .include_input :
539
558
row .append (self .input_renderer .render_value (None , case .inputs ) or EMPTY_CELL_STR )
540
559
560
+ if self .include_metadata :
561
+ row .append (self .input_renderer .render_value (None , case .metadata ) or EMPTY_CELL_STR )
562
+
563
+ if self .include_expected_output :
564
+ row .append (self .input_renderer .render_value (None , case .expected_output ) or EMPTY_CELL_STR )
565
+
541
566
if self .include_output :
542
567
row .append (self .output_renderer .render_value (None , case .output ) or EMPTY_CELL_STR )
543
568
@@ -565,6 +590,12 @@ def build_aggregate_row(self, aggregate: ReportCaseAggregate) -> list[str]:
565
590
if self .include_input :
566
591
row .append (EMPTY_AGGREGATE_CELL_STR )
567
592
593
+ if self .include_metadata :
594
+ row .append (EMPTY_AGGREGATE_CELL_STR )
595
+
596
+ if self .include_expected_output :
597
+ row .append (EMPTY_AGGREGATE_CELL_STR )
598
+
568
599
if self .include_output :
569
600
row .append (EMPTY_AGGREGATE_CELL_STR )
570
601
@@ -598,6 +629,19 @@ def build_diff_row(
598
629
input_diff = self .input_renderer .render_diff (None , baseline .inputs , new_case .inputs ) or EMPTY_CELL_STR
599
630
row .append (input_diff )
600
631
632
+ if self .include_metadata :
633
+ metadata_diff = (
634
+ self .metadata_renderer .render_diff (None , baseline .metadata , new_case .metadata ) or EMPTY_CELL_STR
635
+ )
636
+ row .append (metadata_diff )
637
+
638
+ if self .include_expected_output :
639
+ expected_output_diff = (
640
+ self .output_renderer .render_diff (None , baseline .expected_output , new_case .expected_output )
641
+ or EMPTY_CELL_STR
642
+ )
643
+ row .append (expected_output_diff )
644
+
601
645
if self .include_output :
602
646
output_diff = self .output_renderer .render_diff (None , baseline .output , new_case .output ) or EMPTY_CELL_STR
603
647
row .append (output_diff )
@@ -642,6 +686,12 @@ def build_diff_aggregate_row(
642
686
if self .include_input :
643
687
row .append (EMPTY_AGGREGATE_CELL_STR )
644
688
689
+ if self .include_metadata :
690
+ row .append (EMPTY_AGGREGATE_CELL_STR )
691
+
692
+ if self .include_expected_output :
693
+ row .append (EMPTY_AGGREGATE_CELL_STR )
694
+
645
695
if self .include_output :
646
696
row .append (EMPTY_AGGREGATE_CELL_STR )
647
697
@@ -777,6 +827,8 @@ class EvaluationRenderer:
777
827
778
828
# Columns to include
779
829
include_input : bool
830
+ include_metadata : bool
831
+ include_expected_output : bool
780
832
include_output : bool
781
833
include_durations : bool
782
834
include_total_duration : bool
@@ -786,6 +838,7 @@ class EvaluationRenderer:
786
838
include_averages : bool
787
839
788
840
input_config : RenderValueConfig
841
+ metadata_config : RenderValueConfig
789
842
output_config : RenderValueConfig
790
843
score_configs : dict [str , RenderNumberConfig ]
791
844
label_configs : dict [str , RenderValueConfig ]
@@ -820,6 +873,7 @@ def _get_case_renderer(
820
873
self , report : EvaluationReport , baseline : EvaluationReport | None = None
821
874
) -> ReportCaseRenderer :
822
875
input_renderer = _ValueRenderer .from_config (self .input_config )
876
+ metadata_renderer = _ValueRenderer .from_config (self .metadata_config )
823
877
output_renderer = _ValueRenderer .from_config (self .output_config )
824
878
score_renderers = self ._infer_score_renderers (report , baseline )
825
879
label_renderers = self ._infer_label_renderers (report , baseline )
@@ -830,6 +884,8 @@ def _get_case_renderer(
830
884
831
885
return ReportCaseRenderer (
832
886
include_input = self .include_input ,
887
+ include_metadata = self .include_metadata ,
888
+ include_expected_output = self .include_expected_output ,
833
889
include_output = self .include_output ,
834
890
include_scores = self .include_scores (report , baseline ),
835
891
include_labels = self .include_labels (report , baseline ),
@@ -838,6 +894,7 @@ def _get_case_renderer(
838
894
include_durations = self .include_durations ,
839
895
include_total_duration = self .include_total_duration ,
840
896
input_renderer = input_renderer ,
897
+ metadata_renderer = metadata_renderer ,
841
898
output_renderer = output_renderer ,
842
899
score_renderers = score_renderers ,
843
900
label_renderers = label_renderers ,
0 commit comments