Skip to content

Commit b35b050

Browse files
authored
Merge pull request #2176 from jerneju/string-filter-table
[FIX] Select Rows and Table: Filtering string values
2 parents e815184 + f379c0b commit b35b050

File tree

3 files changed: 204 additions and 100 deletions.

.travis/stage_after_success.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ if [ "$BUILD_DOCS" ] &&
99
fi
1010

1111
if [ "$UPLOAD_COVERAGE" ]; then
12+
cp $TRAVIS_BUILD_DIR/codecov.yml .
1213
codecov
1314
fi
1415

Orange/data/table.py

Lines changed: 164 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,10 +1133,26 @@ def _filter_same_value(self, column, value, negate=False):
11331133
sel = np.logical_not(sel)
11341134
return self.from_table_rows(self, sel)
11351135

1136-
def _filter_values_indicators(self, filter):
1137-
from Orange.data import filter as data_filter
1136+
def _filter_values(self, filter):
    """Return a new table holding only the rows selected by `filter`.

    The row selection is computed by `_values_filter_to_indicator` and
    passed to `from_table` together with this table and its own domain.
    """
    return self.from_table(
        self.domain, self, self._values_filter_to_indicator(filter))
1139+
1140+
def _values_filter_to_indicator(self, filter):
1141+
"""Return selection of rows matching the filter conditions
1142+
1143+
Handles conjunction/disjunction and negate modifiers
1144+
1145+
Parameters
1146+
----------
1147+
filter: Values object containing the conditions
1148+
1149+
Returns
1150+
-------
1151+
A 1d bool array. len(result) == len(self)
1152+
"""
1153+
from Orange.data.filter import Values
11381154

1139-
if isinstance(filter, data_filter.Values):
1155+
if isinstance(filter, Values):
11401156
conditions = filter.conditions
11411157
conjunction = filter.conjunction
11421158
else:
@@ -1148,109 +1164,157 @@ def _filter_values_indicators(self, filter):
11481164
sel = np.zeros(len(self), dtype=bool)
11491165

11501166
for f in conditions:
1151-
if isinstance(f, data_filter.Values):
1152-
if conjunction:
1153-
sel *= self._filter_values_indicators(f)
1154-
else:
1155-
sel += self._filter_values_indicators(f)
1156-
continue
1157-
col = self.get_column_view(f.column)[0]
1158-
if isinstance(f, data_filter.FilterDiscrete) and f.values is None \
1159-
or isinstance(f, data_filter.FilterContinuous) and \
1160-
f.oper == f.IsDefined:
1161-
col = col.astype(float)
1162-
if conjunction:
1163-
sel *= ~np.isnan(col)
1164-
else:
1165-
sel += ~np.isnan(col)
1166-
elif isinstance(f, data_filter.FilterString) and \
1167-
f.oper == f.IsDefined:
1168-
if conjunction:
1169-
sel *= col.astype(bool)
1170-
else:
1171-
sel += col.astype(bool)
1172-
elif isinstance(f, data_filter.FilterDiscrete):
1173-
if conjunction:
1174-
s2 = np.zeros(len(self), dtype=bool)
1175-
for val in f.values:
1176-
if not isinstance(val, Real):
1177-
val = self.domain[f.column].to_val(val)
1178-
s2 += (col == val)
1179-
sel *= s2
1180-
else:
1181-
for val in f.values:
1182-
if not isinstance(val, Real):
1183-
val = self.domain[f.column].to_val(val)
1184-
sel += (col == val)
1185-
elif isinstance(f, data_filter.FilterStringList):
1186-
if not f.case_sensitive:
1187-
# noinspection PyTypeChecker
1188-
col = np.char.lower(np.array(col, dtype=str))
1189-
vals = [val.lower() for val in f.values]
1190-
else:
1191-
vals = f.values
1192-
if conjunction:
1193-
sel *= reduce(operator.add,
1194-
(col == val for val in vals))
1195-
else:
1196-
sel = reduce(operator.add,
1197-
(col == val for val in vals), sel)
1198-
elif isinstance(f, data_filter.FilterRegex):
1199-
sel = np.vectorize(f)(col)
1200-
elif isinstance(f, (data_filter.FilterContinuous,
1201-
data_filter.FilterString)):
1202-
if (isinstance(f, data_filter.FilterString) and
1203-
not f.case_sensitive):
1204-
# noinspection PyTypeChecker
1205-
col = np.char.lower(np.array(col, dtype=str))
1206-
fmin = f.min.lower()
1207-
if f.oper in [f.Between, f.Outside]:
1208-
fmax = f.max.lower()
1209-
else:
1210-
fmin, fmax = f.min, f.max
1211-
if f.oper == f.Equal:
1212-
col = (col == fmin)
1213-
elif f.oper == f.NotEqual:
1214-
col = (col != fmin)
1215-
elif f.oper == f.Less:
1216-
col = (col < fmin)
1217-
elif f.oper == f.LessEqual:
1218-
col = (col <= fmin)
1219-
elif f.oper == f.Greater:
1220-
col = (col > fmin)
1221-
elif f.oper == f.GreaterEqual:
1222-
col = (col >= fmin)
1223-
elif f.oper == f.Between:
1224-
col = (col >= fmin) * (col <= fmax)
1225-
elif f.oper == f.Outside:
1226-
col = (col < fmin) + (col > fmax)
1227-
elif not isinstance(f, data_filter.FilterString):
1228-
raise TypeError("Invalid operator")
1229-
elif f.oper == f.Contains:
1230-
col = np.fromiter((fmin in e for e in col),
1231-
dtype=bool)
1232-
elif f.oper == f.StartsWith:
1233-
col = np.fromiter((e.startswith(fmin) for e in col),
1234-
dtype=bool)
1235-
elif f.oper == f.EndsWith:
1236-
col = np.fromiter((e.endswith(fmin) for e in col),
1237-
dtype=bool)
1238-
else:
1239-
raise TypeError("Invalid operator")
1240-
if conjunction:
1241-
sel *= col
1242-
else:
1243-
sel += col
1167+
selection = self._filter_to_indicator(f)
1168+
1169+
if conjunction:
1170+
sel *= selection
12441171
else:
1245-
raise TypeError("Invalid filter")
1172+
sel += selection
12461173

12471174
if filter.negate:
12481175
sel = ~sel
12491176
return sel
12501177

1251-
def _filter_values(self, filter):
1252-
sel = self._filter_values_indicators(filter)
1253-
return self.from_table(self.domain, self, sel)
1178+
def _filter_to_indicator(self, filter):
    """Return selection of rows that match the condition.

    Dispatches on the type of `filter`: a nested `Values` filter is
    handed to `_values_filter_to_indicator`; discrete, continuous and
    string filters are delegated to their dedicated helpers; string-list
    and regex filters are evaluated here.

    Parameters
    ----------
    filter: ValueFilter describing the condition

    Returns
    -------
    A 1d bool array. len(result) == len(self)

    Raises
    ------
    TypeError
        If `filter` is not one of the supported filter types.
    """
    from Orange.data.filter import (
        FilterContinuous, FilterDiscrete, FilterRegex, FilterString,
        FilterStringList, Values
    )
    if isinstance(filter, Values):
        return self._values_filter_to_indicator(filter)

    col = self.get_column_view(filter.column)[0]

    if isinstance(filter, FilterDiscrete):
        return self._discrete_filter_to_indicator(filter, col)

    if isinstance(filter, FilterContinuous):
        return self._continuous_filter_to_indicator(filter, col)

    if isinstance(filter, FilterString):
        return self._string_filter_to_indicator(filter, col)

    if isinstance(filter, FilterStringList):
        if not filter.case_sensitive:
            # Compare lower-cased copies of both the column and the values.
            col = np.char.lower(np.array(col, dtype=str))
            vals = [val.lower() for val in filter.values]
        else:
            vals = filter.values
        # Seed the reduction with an all-False selection so that an empty
        # list of values matches no rows instead of making reduce() raise
        # "reduce() of empty sequence with no initial value".
        no_rows = np.zeros(len(self), dtype=bool)
        return reduce(operator.add, (col == val for val in vals), no_rows)

    if isinstance(filter, FilterRegex):
        # An explicit output type lets np.vectorize handle a column with
        # zero rows (it cannot infer the type from an empty input) and
        # guarantees the documented bool dtype.
        return np.vectorize(filter, otypes=[bool])(col)

    raise TypeError("Invalid filter")
1219+
1220+
def _discrete_filter_to_indicator(self, filter, col):
1221+
"""Return selection of rows matched by the given discrete filter.
1222+
1223+
Parameters
1224+
----------
1225+
filter: FilterDiscrete
1226+
col: np.ndarray
1227+
1228+
Returns
1229+
-------
1230+
A 1d bool array. len(result) == len(self)
1231+
"""
1232+
if filter.values is None: # <- is defined filter
1233+
col = col.astype(float)
1234+
return ~np.isnan(col)
1235+
1236+
sel = np.zeros(len(self), dtype=bool)
1237+
for val in filter.values:
1238+
if not isinstance(val, Real):
1239+
val = self.domain[filter.column].to_val(val)
1240+
sel += (col == val)
1241+
return sel
1242+
1243+
def _continuous_filter_to_indicator(self, filter, col):
1244+
"""Return selection of rows matched by the given continuous filter.
1245+
1246+
Parameters
1247+
----------
1248+
filter: FilterContinuous
1249+
col: np.ndarray
1250+
1251+
Returns
1252+
-------
1253+
A 1d bool array. len(result) == len(self)
1254+
"""
1255+
if filter.oper == filter.IsDefined:
1256+
col = col.astype(float)
1257+
return ~np.isnan(col)
1258+
1259+
return self._range_filter_to_indicator(filter, col, filter.min, filter.max)
1260+
1261+
def _string_filter_to_indicator(self, filter, col):
1262+
"""Return selection of rows matched by the given string filter.
1263+
1264+
Parameters
1265+
----------
1266+
filter: FilterString
1267+
col: np.ndarray
1268+
1269+
Returns
1270+
-------
1271+
A 1d bool array. len(result) == len(self)
1272+
"""
1273+
if filter.oper == filter.IsDefined:
1274+
return col.astype(bool)
1275+
1276+
col = col.astype(str)
1277+
fmin = filter.min or ""
1278+
fmax = filter.max or ""
1279+
1280+
if not filter.case_sensitive:
1281+
# convert all to lower case
1282+
col = np.char.lower(col)
1283+
fmin = fmin.lower()
1284+
fmax = fmax.lower()
1285+
1286+
if filter.oper == filter.Contains:
1287+
return np.fromiter((fmin in e for e in col),
1288+
dtype=bool)
1289+
if filter.oper == filter.StartsWith:
1290+
return np.fromiter((e.startswith(fmin) for e in col),
1291+
dtype=bool)
1292+
if filter.oper == filter.EndsWith:
1293+
return np.fromiter((e.endswith(fmin) for e in col),
1294+
dtype=bool)
1295+
1296+
return self._range_filter_to_indicator(filter, col, fmin, fmax)
1297+
1298+
@staticmethod
1299+
def _range_filter_to_indicator(filter, col, fmin, fmax):
1300+
if filter.oper == filter.Equal:
1301+
return col == fmin
1302+
if filter.oper == filter.NotEqual:
1303+
return col != fmin
1304+
if filter.oper == filter.Less:
1305+
return col < fmin
1306+
if filter.oper == filter.LessEqual:
1307+
return col <= fmin
1308+
if filter.oper == filter.Greater:
1309+
return col > fmin
1310+
if filter.oper == filter.GreaterEqual:
1311+
return col >= fmin
1312+
if filter.oper == filter.Between:
1313+
return (col >= fmin) * (col <= fmax)
1314+
if filter.oper == filter.Outside:
1315+
return (col < fmin) + (col > fmax)
1316+
1317+
raise TypeError("Invalid operator")
12541318

12551319
def _compute_basic_stats(self, columns=None,
12561320
include_metas=False, compute_variance=False):

Orange/tests/test_table.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,27 @@ def test_filter_values_nested(self):
881881
f = filter.Values([filter.Values([f1, f2], conjunction=False), f3])
882882
self.assertEqual(41, len(f(d)))
883883

884+
def test_filter_string_works_for_numeric_columns(self):
    """FilterString can be applied to a string meta holding numbers.

    The expected counts below correspond to comparing the values as
    text (e.g. "6" > "5" but "10" is not).
    """
    var = StringVariable("s")
    # Rows are built from range(21), i.e. the values 0, 1, 2, ..., 20.
    rows = [[number] for number in range(21)]
    data = Table(Domain([], metas=[var]), rows)

    string_filter = filter.FilterString
    cases = [
        # matches 6, 7, 8, 9
        ((string_filter.Greater, "5"), 4),
        # matches 15, 16, 17, 18, 19, 2
        ((string_filter.Between, "15", "2"), 6),
        # matches 2, 12, 20
        ((string_filter.Contains, "2"), 3),
    ]

    for args, expected_rows in cases:
        condition = string_filter(var, *args)
        selected = filter.Values([condition])(data)
        self.assertEqual(len(selected), expected_rows,
                         "{} returned wrong number of rows".format(args))
904+
884905
def test_filter_value_continuous(self):
885906
d = data.Table("iris")
886907
col = d.X[:, 2]
@@ -1181,6 +1202,24 @@ def test_valueFilter_regex(self):
11811202
x = filter.Values([f])(d)
11821203
self.assertEqual(len(x), 7)
11831204

1205+
def test_valueFilter_stringList(self):
    """FilterStringList on the zoo "name" column, with and without
    case-sensitive matching (second constructor argument)."""
    data = Table("zoo")
    var = data.domain["name"]

    string_list_filter = filter.FilterStringList
    cases = [
        # lower-case names match regardless of the flag
        ((["swan", "tuna", "wasp"], True), 3),
        ((["swan", "tuna", "wasp"], False), 3),
        # mixed-case names match only when case is ignored
        ((["WoRm", "TOad", "vOLe"], True), 0),
        ((["WoRm", "TOad", "vOLe"], False), 3),
    ]

    for args, expected_rows in cases:
        condition = string_list_filter(var, *args)
        selected = filter.Values([condition])(data)
        self.assertEqual(len(selected), expected_rows,
                         "{} returned wrong number of rows".format(args))
1222+
11841223
def test_table_dtypes(self):
11851224
table = data.Table("iris")
11861225
metas = np.hstack((table.metas, table.Y.reshape(len(table), 1)))

0 commit comments

Comments
 (0)