Skip to content

Commit 3e8a120

Browse files
authored
RowTests for TextEvals (#1691)
* RowTests for TextEvals — fix row tests charts
* fix lint
* fix tests
1 parent 79ec106 commit 3e8a120

File tree

4 files changed

+9
-3
lines changed

4 files changed

+9
-3
lines changed

src/evidently/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@
6666
from .regression import DummyRMSE
6767
from .regression import MeanError
6868
from .regression import R2Score
69+
from .row_test_summary import RowTestSummary
6970

7071
__all__ = [
7172
"GroupBy",
73+
"RowTestSummary",
7274
# column statistics metrics
7375
"CategoryCount",
7476
"ValueDrift",

src/evidently/metrics/row_test_summary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def _render_test_column_widget(
4545
assert isinstance(result, SingleValue)
4646
success_count = int(result.value * row_count)
4747
if result.value == 1:
48-
return pie_chart(title=title, data=(["PASSED"], [success_count]), colors=["GREEN"], size=size)
48+
return pie_chart(title=title, data=(["PASSED"], [row_count]), colors=["GREEN"], size=size)
4949
if result.value == 0:
50-
return pie_chart(title=title, data=(["FAILED"], [0]), colors=["RED"], size=size)
50+
return pie_chart(title=title, data=(["FAILED"], [row_count]), colors=["RED"], size=size)
5151
return pie_chart(
5252
title=title,
5353
data=(["PASSED", "FAILED"], [success_count, row_count - success_count]),

src/evidently/presets/dataset_stats.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from evidently.metrics.dataset_statistics import DuplicatedColumnsCount
4747
from evidently.metrics.dataset_statistics import EmptyColumnsCount
4848
from evidently.metrics.dataset_statistics import EmptyRowsCount
49+
from evidently.metrics.row_test_summary import RowTestSummary
4950

5051

5152
class ValueStats(ColumnMetricContainer):
@@ -468,10 +469,11 @@ def get_value_stats(self, context: Context) -> List[ValueStats]:
468469
include_tests=self.include_tests,
469470
)
470471
for column in cols
472+
if column not in (context.data_definition.test_descriptors or [])
471473
]
472474

473475
def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
474-
metrics: List[MetricOrContainer] = [RowCount(tests=self._get_tests(self.row_count_tests))]
476+
metrics: List[MetricOrContainer] = [RowTestSummary(), RowCount(tests=self._get_tests(self.row_count_tests))]
475477
value_stats = self.get_value_stats(context)
476478
metrics.extend(list(chain(*[vs.metrics(context)[1:] for vs in value_stats])))
477479
return metrics

tests/future/presets/test_test_fields.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def value_stats_tests_check(metric: Metric, tests: Dict[str, ValueStatsTests]):
110110
return metric.tests == getattr(ts, field_name)
111111
if isinstance(metric, QuantileValue):
112112
return metric.tests == getattr(ts, f"q{int(metric.quantile * 100)}_tests")
113+
if isinstance(metric, RowTestSummary):
114+
return True
113115
raise ValueError(f"Unknown metric type {metric.__class__.__name__}")
114116

115117

0 commit comments

Comments (0)