Skip to content

Commit b2dcdb7

Browse files
authored
Fix RowTestSummary tests. (#1692)
1 parent 3e8a120 commit b2dcdb7

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/evidently/metrics/row_test_summary.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,25 @@ class RowTestSummary(MetricContainer):
2323

2424
def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
2525
test_columns = self.get_test_columns(context)
26-
return [self.get_test_column_metric(tc) for tc in test_columns] + [self.get_row_count_metric()]
26+
return [self.get_test_column_metric(tc, context) for tc in test_columns] + [self.get_row_count_metric()]
2727

2828
def get_row_count_metric(self):
2929
return RowCount()
3030

3131
def get_test_columns(self, context):
3232
return self.columns or context.data_definition.test_descriptors or []
3333

34-
def get_test_column_metric(self, test_column: str) -> Metric:
35-
return MeanValue(
36-
column=test_column, tests=[gte(self.min_success_rate, alias=f"Share of passed '{test_column}' row tests")]
37-
)
34+
def get_test_column_metric(self, test_column: str, context) -> Metric:
35+
tests = None
36+
if context.configuration.include_tests:
37+
tests = [gte(self.min_success_rate, alias=f"Share of passed '{test_column}' row tests")]
38+
return MeanValue(column=test_column, tests=tests)
3839

3940
def _render_test_column_widget(
4041
self, context: "Context", test_column: str, row_count: int, size: WidgetSize
4142
) -> BaseWidgetInfo:
4243
title = f"Row test '{test_column}'"
43-
metric = self.get_test_column_metric(test_column)
44+
metric = self.get_test_column_metric(test_column, context)
4445
result = context.get_metric_result(metric)
4546
assert isinstance(result, SingleValue)
4647
success_count = int(result.value * row_count)

src/evidently/presets/dataset_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def render(
484484
child_widgets: Optional[List[Tuple[Optional[MetricId], List[BaseWidgetInfo]]]] = None,
485485
) -> List[BaseWidgetInfo]:
486486
value_stats = self.get_value_stats(context)
487-
return list(chain(*[vs.render(context) for vs in value_stats]))
487+
return list(chain(*([RowTestSummary().render(context)] + [vs.render(context) for vs in value_stats])))
488488

489489

490490
class DataSummaryPreset(MetricContainer):

0 commit comments

Comments
 (0)