Skip to content

Commit 182f994

Browse files
janezd authored and markotoplak committed
Merge pull request biolab#5763 from PrimozGodec/fix-groupby
[FIX] Group by: compute mode when all values in group nan
1 parent c08c34b commit 182f994

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

Orange/widgets/data/owgroupby.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict, List, Optional, Set
66

77
import pandas as pd
8+
from numpy import nan
89
from AnyQt.QtCore import (
910
QAbstractTableModel,
1011
QEvent,
@@ -58,7 +59,7 @@ def concatenate(x):
5859
"Mean": Aggregation("mean", {ContinuousVariable, TimeVariable}),
5960
"Median": Aggregation("median", {ContinuousVariable, TimeVariable}),
6061
"Mode": Aggregation(
61-
lambda x: pd.Series.mode(x)[0], {ContinuousVariable, TimeVariable}
62+
lambda x: pd.Series.mode(x).get(0, nan), {ContinuousVariable, TimeVariable}
6263
),
6364
"Standard deviation": Aggregation("std", {ContinuousVariable, TimeVariable}),
6465
"Variance": Aggregation("var", {ContinuousVariable, TimeVariable}),
@@ -404,7 +405,7 @@ def __gb_changed(self) -> None:
404405
self.gb_attrs = [values[row.row()] for row in sorted(rows)]
405406
# everything cached in result should be recomputed on gb change
406407
self.result = Result()
407-
self.commit()
408+
self.commit.deferred()
408409

409410
def __aggregation_changed(self, agg: str) -> None:
410411
"""
@@ -420,7 +421,7 @@ def __aggregation_changed(self, agg: str) -> None:
420421
else:
421422
self.aggregations[attr].discard(agg)
422423
self.agg_table_model.update_aggregation(attr)
423-
self.commit()
424+
self.commit.deferred()
424425

425426
@Inputs.data
426427
def set_data(self, data: Table) -> None:
@@ -448,11 +449,12 @@ def set_data(self, data: Table) -> None:
448449
self.agg_table_model.set_domain(data.domain if data else None)
449450
self._set_gb_selection()
450451

451-
self.commit()
452+
self.commit.now()
452453

453454
#########################
454455
# Task connected methods
455456

457+
@gui.deferred
456458
def commit(self) -> None:
457459
self.Error.clear()
458460
self.Warning.clear()

Orange/widgets/data/tests/test_owgroupby.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
from Orange.data import (
1313
Table,
1414
table_to_frame,
15+
Domain,
16+
ContinuousVariable,
1517
)
1618
from Orange.data.tests.test_aggregate import create_sample_data
1719
from Orange.widgets.data.owgroupby import OWGroupBy
@@ -689,6 +691,65 @@ def test_time_variable(self):
689691
output = self.get_output(self.widget.Outputs.data)
690692
self.assertEqual(2, len(output))
691693

694+
def test_only_nan_in_group(self):
695+
data = Table(
696+
Domain([ContinuousVariable("A"), ContinuousVariable("B")]),
697+
np.array([[1, np.nan], [2, 1], [1, np.nan], [2, 1]]),
698+
)
699+
self.send_signal(self.widget.Inputs.data, data)
700+
701+
# select feature A as group-by
702+
self._set_selection(self.widget.gb_attrs_view, [0])
703+
# select all aggregations for feature B
704+
self.select_table_rows(self.widget.agg_table_view, [1])
705+
for cb in self.widget.agg_checkboxes.values():
706+
while not cb.isChecked():
707+
cb.click()
708+
709+
# unselect all aggregations for attr A
710+
self.select_table_rows(self.widget.agg_table_view, [0])
711+
for cb in self.widget.agg_checkboxes.values():
712+
while cb.isChecked():
713+
cb.click()
714+
715+
expected_columns = [
716+
"B - Mean",
717+
"B - Median",
718+
"B - Mode",
719+
"B - Standard deviation",
720+
"B - Variance",
721+
"B - Sum",
722+
"B - Min. value",
723+
"B - Max. value",
724+
"B - Span",
725+
"B - First value",
726+
"B - Last value",
727+
"B - Random value",
728+
"B - Count defined",
729+
"B - Count",
730+
"B - Proportion defined",
731+
"B - Concatenate",
732+
"A",
733+
]
734+
n = np.nan
735+
expected_df = pd.DataFrame(
736+
[
737+
[n, n, n, n, n, 0, n, n, n, n, n, n, 0, 2, 0, "", 1],
738+
[1, 1, 1, 0, 0, 2, 1, 1, 0, 1, 1, 1, 2, 2, 1, "1.0 1.0", 2],
739+
],
740+
columns=expected_columns,
741+
)
742+
output_df = table_to_frame(
743+
self.get_output(self.widget.Outputs.data), include_metas=True
744+
)
745+
pd.testing.assert_frame_equal(
746+
output_df,
747+
expected_df,
748+
check_dtype=False,
749+
check_column_type=False,
750+
check_categorical=False,
751+
)
752+
692753

693754
if __name__ == "__main__":
694755
unittest.main()

0 commit comments

Comments
 (0)