Skip to content

Commit 25c92df

Browse files
committed
improve type inference
1 parent 1d8b32e commit 25c92df

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

data_algebra/util.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,21 @@ def guess_carried_scalar_type(col):
7575
:param col: column or scalar to inspect
7676
:return: type of first non-None entry, if any , else type(None)
7777
"""
78+
# check for scalars first
7879
ct = map_type_to_canonical(type(col))
7980
if ct in {str, int, float, bool, type(None), numpy.int64, numpy.float64,
8081
datetime.datetime, datetime.date, datetime.timedelta}:
8182
return ct
83+
# look at a list or Series
84+
if isinstance(col, data_algebra.default_data_model.pd.core.series.Series):
85+
col = col.values
8286
if len(col) < 1:
8387
return type(None)
84-
idx = col.notna().idxmax()
85-
if idx is None:
86-
return map_type_to_canonical(type(col[0]))
87-
return map_type_to_canonical(type(col[idx]))
88+
good_idx = numpy.where(numpy.logical_not(data_algebra.default_data_model.pd.isna(col)))[0]
89+
test_idx = 0
90+
if len(good_idx) > 0:
91+
test_idx = good_idx[0]
92+
return map_type_to_canonical(type(col[test_idx]))
8893

8994

9095
def guess_column_types(d, *, columns=None):
@@ -106,7 +111,7 @@ def guess_column_types(d, *, columns=None):
106111
res = dict()
107112
for c in columns:
108113
gt = guess_carried_scalar_type(d[c])
109-
if (gt is None) or (not isinstance(gt, type)) or str(gt).endswith('.Series\'>'):
114+
if (gt is None) or (not isinstance(gt, type)) or gt == data_algebra.default_data_model.pd.core.series.Series:
110115
# pandas.concat() poisons types with Series, don't allow that
111116
return dict()
112117
res[c] = gt

tests/test_type_check_problem.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11

2-
import numpy
3-
42
import data_algebra
53
from data_algebra.data_ops import *
64
import data_algebra.util
75
import data_algebra.test_util
86

9-
import pytest
10-
117

128
def test_type_check_problem_1():
139
d = data_algebra.default_data_model.pd.DataFrame({
@@ -16,8 +12,13 @@ def test_type_check_problem_1():
1612
assert not isinstance(d['x'][0], data_algebra.default_data_model.pd.core.series.Series)
1713
d2 = data_algebra.default_data_model.pd.concat([d, d])
1814
# pandas concat mucks columns into series (reset_index() fixes that)
15+
# check that we see that problem
1916
assert isinstance(d2['x'][0], data_algebra.default_data_model.pd.core.series.Series)
17+
# reset index can clear the problem
2018
d3 = d2.reset_index(drop=True, inplace=False)
2119
assert not isinstance(d3['x'][0], data_algebra.default_data_model.pd.core.series.Series)
20+
# see what type inspection we get
2221
td2 = describe_table(d2, table_name='d2')
23-
assert (td2.column_types is None) or (len(td2.column_types) == 0) # don't pick up Series types!
22+
td3 = describe_table(d2, table_name='d3')
23+
assert td3.column_types['x'] == int
24+
assert td2.column_types == td3.column_types # don't pick up Series types!

0 commit comments

Comments
 (0)