Skip to content

Commit 9cdfef7

Browse files
authored
Merge pull request #22 from zStupan/fix-dataset
Fix dataset
2 parents 89e3173 + be3e585 commit 9cdfef7

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

niaarm/dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,16 @@ def __analyse_types(self):
3939
min_value = col.min()
4040
max_value = col.max()
4141
unique_categories = None
42+
elif col.dtype == 'bool':
43+
self.data[head] = self.data[head].astype(int)
44+
self.transactions = self.data.values
45+
dtype = 'int'
46+
min_value = 0
47+
max_value = 1
48+
unique_categories = None
4249
else:
4350
dtype = "cat"
44-
unique_categories = sorted(col.astype(str).unique().tolist(), key=str.lower)
51+
unique_categories = sorted(col.unique().tolist(), key=str.lower)
4552
min_value = None
4653
max_value = None
4754

niaarm/tests/test_read_csv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_read_features(self):
2828
'float',
2929
'float',
3030
'int']
31-
31+
3232
data = Dataset(os.path.join(os.path.dirname(__file__), 'test_data', 'Abalone.csv'))
3333

3434
features = data.features
@@ -60,7 +60,7 @@ def test_read_features(self):
6060
minval = [None, 0]
6161
maxval = [None, 1]
6262
dtypes_a = ['cat', 'int']
63-
63+
6464
data = Dataset(os.path.join(os.path.dirname(__file__), 'test_data', 'wiki_test_case.csv'))
6565

6666
features = data.features

tests/test_niaarm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from niaarm.tests.conftest import pytest_configure
22

33
__all__ = ["pytest_configure"]
4-

0 commit comments

Comments
 (0)