Skip to content

Commit e970535

Browse files
committed
CSV Import: guess data types
1 parent 49e433c commit e970535

File tree

3 files changed

+128
-1
lines changed

3 files changed

+128
-1
lines changed

Orange/widgets/data/owcsvimport.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ def __set_error_state(self, err):
907907
if isinstance(err, UnicodeDecodeError):
908908
self.Error.encoding_error(exc_info=err)
909909
else:
910+
raise err
910911
self.Error.error(exc_info=err)
911912

912913
path = self.current_item().path()
@@ -1259,6 +1260,8 @@ def expand(ranges):
12591260
float_precision="round_trip",
12601261
**numbers_format_kwds
12611262
)
1263+
df = guess_types(df, dtypes, columns_ignored)
1264+
12621265
if columns_ignored:
12631266
# TODO: use 'usecols' parameter in `read_csv` call to
12641267
# avoid loading/parsing the columns
@@ -1270,6 +1273,86 @@ def expand(ranges):
12701273
return df
12711274

12721275

1276+
def guess_types(
        df: pd.DataFrame, dtypes: Dict[int, str], columns_ignored: List[int]
) -> pd.DataFrame:
    """
    Infer dtypes for the columns whose type was left on automatic.

    Parameters
    ----------
    df
        Data frame to process. Eligible columns are reassigned on this
        very object.
    dtypes
        Mapping from column position to the data type selected by the user.
        A column is guessed only when it has no entry here (or its entry
        is None).
    columns_ignored
        Positions of the ignored columns. These are never touched.

    Returns
    -------
    The same data frame, with guessed dtypes on the eligible columns.
    """
    for position, name in enumerate(df):
        user_defined = dtypes.get(position) is not None
        if user_defined or position in columns_ignored:
            # an explicit type was chosen in the dialog or the column is
            # ignored - nothing to guess
            continue
        df[name] = guess_data_type(df[name])
    return df
1301+
1302+
1303+
def guess_data_type(col: pd.Series) -> pd.Series:
    """
    Guess the type of a column and return it with a matching dtype.

    The logic is the same as in guess_data_type from the io_utils module.
    This function only changes the dtype of the column so that the correct
    Orange.data.variable is used later.

    Logic:
    - if numeric (only numbers)
        - only values {0, 1} or {1, 2} -> category (DiscreteVariable)
        - else -> unchanged (ContinuousVariable)
    - if not numbers:
        - every unique value is accepted by pd.to_datetime
          -> datetime (TimeVariable)
        - num_unique_values < len(data) ** 0.7 and < 100
          -> category (DiscreteVariable)
        - else -> unchanged (StringVariable)

    Parameters
    ----------
    col
        Data column

    Returns
    -------
    Data column with the guessed dtype. The input column itself is returned
    when no conversion applies.
    """
    def parse_dates(s):
        """
        Parse the column as date-times, or return None if it is not one.

        For large data the same dates are often repeated. Rather than
        re-parse these, every unique value is parsed once and the
        resulting lookup is used to convert the whole column.
        """
        try:
            dates = {date: pd.to_datetime(date) for date in s.unique()}
        except (ValueError, TypeError, OverflowError):
            # ValueError: the value is not a parsable (or in-range) date.
            # TypeError: the value has a type pd.to_datetime refuses, e.g.
            # Python bools kept in an object column next to missing values.
            # OverflowError: caught defensively for values too large to
            # convert.
            # In every case the column simply is not a date column; it must
            # not abort the import.
            return None
        return s.map(dates)

    if pdtypes.is_numeric_dtype(col):
        unique_values = col.unique()
        is_binary = len(unique_values) <= 2 and (
            len(np.setdiff1d(unique_values, [0, 1])) == 0
            or len(np.setdiff1d(unique_values, [1, 2])) == 0)
        if is_binary:
            return col.astype("category")
        return col

    # not numeric (object): try to parse as date - None means not a date
    parsed_col = parse_dates(col)
    if parsed_col is not None:
        return parsed_col

    unique_values = col.unique()
    if len(unique_values) < 100 and len(unique_values) < len(col) ** 0.7:
        return col.astype("category")
    return col
1354+
1355+
12731356
def clear_stack_on_cancel(f):
12741357
"""
12751358
A decorator that catches the TaskState.UserCancelException exception
@@ -1465,7 +1548,8 @@ def pandas_to_table(df):
14651548
)
14661549
# Remap the coldata into the var.values order/set
14671550
coldata = pd.Categorical(
1468-
coldata, categories=var.values, ordered=coldata.ordered
1551+
coldata.astype("str"), categories=var.values,
1552+
ordered=coldata.ordered,
14691553
)
14701554
codes = coldata.codes
14711555
assert np.issubdtype(codes.dtype, np.integer)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
time numeric1 discrete1 numeric2 discrete2 string
2+
2020-05-05 1 0 a a
3+
2020-05-06 2 1 a b
4+
2020-05-07 3 0 a c
5+
2020-05-08 4 1 b d
6+
2020-05-09 5 1 b e

Orange/widgets/data/tests/test_owcsvimport.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from AnyQt.QtCore import QSettings
1414

15+
from Orange.data import DiscreteVariable, TimeVariable, ContinuousVariable, \
16+
StringVariable
1517
from Orange.tests import named_file
1618
from Orange.widgets.tests.base import WidgetTest, GuiTest
1719
from Orange.widgets.data import owcsvimport
@@ -127,6 +129,37 @@ def test_summary(self):
127129
output_sum.assert_called_with(len(output),
128130
format_summary_details(output))
129131

132+
# Import options used by the type guessing test: ASCII encoding, the
# tab-delimited csv dialect, and columns 0-4 left on automatic type.
data_csv_types_options = owcsvimport.Options(
    encoding="ascii",
    dialect=csv.excel_tab(),
    columntypes=[(range(5), ColumnType.Auto)],
)
138+
139+
def test_type_guessing(self):
    """ Columns left on automatic get the variable type their values imply """
    data_dir = os.path.dirname(__file__)
    data_path = os.path.join(data_dir, "data-csv-types.tab")
    session_items = [(data_path, self.data_csv_types_options.as_dict())]
    widget = self.create_widget(
        owcsvimport.OWCSVFileImport,
        stored_settings={"_session_items": session_items},
    )
    widget.commit()
    self.wait_until_finished(widget)
    domain = self.get_output("Data", widget).domain

    # column name -> variable type expected in the output domain
    expected_types = (
        ("time", TimeVariable),
        ("discrete1", DiscreteVariable),
        ("discrete2", DiscreteVariable),
        ("numeric1", ContinuousVariable),
        ("numeric2", ContinuousVariable),
        ("string", StringVariable),
    )
    for column_name, variable_type in expected_types:
        self.assertIsInstance(domain[column_name], variable_type)
162+
130163

131164
class TestImportDialog(GuiTest):
132165
def test_dialog(self):
@@ -253,3 +286,7 @@ class dialect(csv.excel):
253286
assert_array_equal(tb.X[:, 0], [np.nan, 0, np.nan])
254287
assert_array_equal(tb.X[:, 1], [0, np.nan, np.nan])
255288
assert_array_equal(tb.X[:, 2], [np.nan, 1, np.nan])
289+
290+
291+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)