Skip to content

Commit 68c9bab

Browse files
ENH: Support user categories in OneHotEncoder (#727)
* ENH: Support user categories in OneHotEncoder Allows for ```python ohe = OneHotEncoder(categories=[['a', 'b'], ['c', 'd']]) ``` Previously, we required inputs to be CategoricalDtype for dataframes. Closes #726 * bump minimum
1 parent c64b217 commit 68c9bab

File tree

6 files changed

+90
-44
lines changed

6 files changed

+90
-44
lines changed

ci/environment-3.6.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies:
1717
- numpy ==1.17.3
1818
- numpydoc
1919
- packaging
20-
- pandas =0.23.4
20+
- pandas =0.24.2
2121
- psutil
2222
- pytest
2323
- pytest-cov

dask_ml/preprocessing/_encoders.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,21 @@ def _fit(self, X: Union[ArrayLike, DataFrameType], handle_unknown: str = "error"
191191
self.categories_.append(cats)
192192
self.dtypes_.append(None)
193193
else:
194-
if not (X.dtypes == "category").all():
195-
raise ValueError("All columns must be Categorical dtype.")
196-
if self.categories == "auto":
197-
for col in X.columns:
198-
Xi = X[col]
199-
cats = _encode(Xi, uniques=Xi.cat.categories)
200-
self.categories_.append(cats)
201-
self.dtypes_.append(Xi.dtype)
202-
else:
203-
raise ValueError(
204-
"Cannot specify 'categories' with DataFrame input. "
205-
"Use a categorical dtype instead."
206-
)
194+
for i in range(len(X.columns)):
195+
Xi = X.iloc[:, i]
196+
if self.categories != "auto":
197+
categories = self.categories[i]
198+
Xi = Xi.astype(pd.CategoricalDtype(categories))
199+
else:
200+
if not pd.api.types.is_categorical_dtype(Xi.dtype):
201+
raise ValueError(
202+
"All columns must be Categorical dtype when "
203+
"'categories=\"auto\"'."
204+
)
205+
206+
cats = _encode(Xi, uniques=Xi.cat.categories)
207+
self.categories_.append(cats)
208+
self.dtypes_.append(Xi.dtype)
207209

208210
self.categories_ = dask.compute(self.categories_)[0]
209211

@@ -250,23 +252,25 @@ def _transform(
250252
else:
251253
import dask.dataframe as dd
252254

253-
# Validate that all are categorical.
254-
if not (X.dtypes == "category").all():
255-
raise ValueError("Must be all categorical.")
255+
X = X.copy()
256256

257257
if not len(X.columns) == len(self.categories_):
258258
raise ValueError(
259259
"Number of columns ({}) does not match number "
260260
"of categories_ ({})".format(len(X.columns), len(self.categories_))
261261
)
262262

263-
for col, dtype in zip(X.columns, self.dtypes_):
264-
if not (X[col].dtype == dtype):
263+
for i, (col, dtype) in enumerate(zip(X.columns, self.dtypes_)):
264+
Xi = X.iloc[:, i]
265+
if not pd.api.types.is_categorical_dtype(Xi.dtype):
266+
Xi = Xi.astype(dtype)
267+
X[col] = Xi
268+
269+
if Xi.dtype != dtype:
265270
raise ValueError(
266-
"Different CategoricalDtype for fit and "
267-
"transform. '{}' != {}'".format(dtype, X[col].dtype)
271+
"Different CategoricalDtype for fit and transform. "
272+
"{!r} != {!r}".format(Xi.dtype, dtype)
268273
)
269-
270274
return dd.get_dummies(X, sparse=self.sparse, dtype=self.dtype)
271275

272276
return X

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"distributed>=2.4.0",
1616
"numba",
1717
"numpy>=1.17.3",
18-
"pandas>=0.23.4",
18+
"pandas>=0.24.2",
1919
"scikit-learn>=0.23",
2020
"scipy",
2121
"dask-glm>=0.2.0",

tests/preprocessing/test_encoders.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,31 @@ def test_onehotencoder_drop_raises():
117117
dask_ml.preprocessing.OneHotEncoder(drop="first")
118118

119119

120+
def test_onehotencoder_dataframe_with_categories():
121+
# https://github.com/dask/dask-ml/issues/726
122+
enc = dask_ml.preprocessing.OneHotEncoder(
123+
categories=[["a", "b", "c"], ["a", "b"]], sparse=False
124+
)
125+
ddf = dd.from_pandas(
126+
pd.DataFrame({"A": ["a", "b", "b", "a"], "B": ["a", "b", "b", "b"]}),
127+
npartitions=1,
128+
)
129+
result = enc.fit_transform(ddf)
130+
expected = dd.from_pandas(
131+
pd.DataFrame(
132+
{
133+
"A_a": [1, 0, 0, 1],
134+
"A_b": [0, 1, 1, 0],
135+
"A_c": [0, 0, 0, 0],
136+
"B_a": [1, 0, 0, 0],
137+
"B_b": [0, 0, 0, 0],
138+
}
139+
),
140+
npartitions=1,
141+
)
142+
assert_estimator_equal(result, expected)
143+
144+
120145
def test_handles_numpy():
121146
enc = dask_ml.preprocessing.OneHotEncoder()
122147
enc.fit(X)
@@ -132,26 +157,26 @@ def test_dataframe_requires_all_categorical(data):
132157
assert e.match("All columns must be Categorical dtype")
133158

134159

135-
@pytest.mark.parametrize("data", [df, ddf])
136-
def test_dataframe_prohibits_categories(data):
137-
enc = dask_ml.preprocessing.OneHotEncoder(categories=[["a", "b"]])
138-
with pytest.raises(ValueError) as e:
139-
enc.fit(data)
140-
141-
assert e.match("Cannot specify 'categories'")
142-
143-
144160
def test_unknown_category_transform():
145161
df2 = ddf.copy()
146162
df2["A"] = ddf.A.cat.add_categories("new!")
147163

148164
enc = dask_ml.preprocessing.OneHotEncoder()
149165
enc.fit(ddf)
150166

151-
with pytest.raises(ValueError) as e:
167+
with pytest.raises(ValueError, match="Different CategoricalDtype"):
152168
enc.transform(df2)
153169

154-
assert e.match("Different CategoricalDtype for fit and transform")
170+
171+
def test_different_shape_raises():
172+
df2 = ddf.copy()
173+
df2["B"] = ddf.A.cat.add_categories("new!")
174+
175+
enc = dask_ml.preprocessing.OneHotEncoder()
176+
enc.fit(ddf)
177+
178+
with pytest.raises(ValueError, match="Number of columns"):
179+
enc.transform(df2)
155180

156181

157182
@pytest.mark.skipif(not DASK_2_20_0, reason="Fixed in Dask 2.20.0")

tests/test_incremental_pca.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
2424
@pytest.mark.parametrize("batch_number", [3, 10])
25+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
2526
def test_compare_with_sklearn(svd_solver, batch_number):
2627
X = iris.data
2728
X_da = da.from_array(X, chunks=(3, -1))
@@ -52,14 +53,15 @@ def test_compare_with_sklearn(svd_solver, batch_number):
5253

5354

5455
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
56+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
5557
def test_incremental_pca(svd_solver):
5658
# Incremental PCA on dense arrays.
5759
X = iris.data
5860
X = da.from_array(X, chunks=(3, -1))
5961
batch_size = X.shape[0] // 3
6062
ipca = IncrementalPCA(n_components=2, batch_size=batch_size, svd_solver=svd_solver)
6163
pca = PCA(n_components=2, svd_solver=svd_solver)
62-
pca.fit_transform(X)
64+
pca.fit_transform(X.compute())
6365

6466
X_transformed = ipca.fit_transform(X)
6567

@@ -87,6 +89,7 @@ def test_incremental_pca(svd_solver):
8789
)
8890

8991

92+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
9093
def test_incremental_pca_check_projection():
9194
# Test that the projection of data is correct.
9295
rng = np.random.RandomState(1999)
@@ -111,6 +114,7 @@ def test_incremental_pca_check_projection():
111114
assert_almost_equal(np.abs(Yt[0][0]), 1.0, 1)
112115

113116

117+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
114118
def test_incremental_pca_inverse():
115119
# Test that the projection of data can be inverted.
116120
rng = np.random.RandomState(1999)
@@ -154,6 +158,7 @@ def test_incremental_pca_validation():
154158
IncrementalPCA(n_components=n_components).partial_fit(X)
155159

156160

161+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
157162
def test_n_components_none():
158163
# Ensures that n_components == None is handled correctly
159164
rng = np.random.RandomState(1999)
@@ -173,6 +178,7 @@ def test_n_components_none():
173178
assert ipca.n_components_ == ipca.components_.shape[0]
174179

175180

181+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
176182
def test_incremental_pca_set_params():
177183
# Test that components_ sign is stable over batch sizes.
178184
rng = np.random.RandomState(1999)
@@ -200,6 +206,7 @@ def test_incremental_pca_set_params():
200206
ipca.partial_fit(X)
201207

202208

209+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
203210
def test_incremental_pca_num_features_change():
204211
# Test that changing n_components will raise an error.
205212
rng = np.random.RandomState(1999)
@@ -215,6 +222,7 @@ def test_incremental_pca_num_features_change():
215222
ipca.partial_fit(X2)
216223

217224

225+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
218226
def test_incremental_pca_batch_signs():
219227
# Test that components_ sign is stable over batch sizes.
220228
rng = np.random.RandomState(1999)
@@ -232,6 +240,7 @@ def test_incremental_pca_batch_signs():
232240
assert_almost_equal(np.sign(i), np.sign(j), decimal=6)
233241

234242

243+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
235244
def test_incremental_pca_batch_values():
236245
# Test that components_ values are stable over batch sizes.
237246
rng = np.random.RandomState(1999)
@@ -249,6 +258,7 @@ def test_incremental_pca_batch_values():
249258
assert_almost_equal(i, j, decimal=1)
250259

251260

261+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
252262
def test_incremental_pca_batch_rank():
253263
# Test sample size in each batch is always larger or equal to n_components
254264
rng = np.random.RandomState(1999)
@@ -266,6 +276,7 @@ def test_incremental_pca_batch_rank():
266276
assert_allclose_dense_sparse(components_i, components_j)
267277

268278

279+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
269280
def test_incremental_pca_partial_fit():
270281
# Test that fit and partial_fit get equivalent results.
271282
rng = np.random.RandomState(1999)
@@ -288,12 +299,13 @@ def test_incremental_pca_partial_fit():
288299

289300

290301
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
302+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
291303
def test_incremental_pca_against_pca_iris(svd_solver):
292304
# Test that IncrementalPCA and PCA are approximate (to a sign flip).
293305
X = iris.data
294306
X = da.from_array(X, chunks=[50, -1])
295307

296-
Y_pca = PCA(n_components=2, svd_solver=svd_solver).fit_transform(X)
308+
Y_pca = PCA(n_components=2, svd_solver=svd_solver).fit_transform(X.compute())
297309
Y_ipca = IncrementalPCA(
298310
n_components=2, batch_size=25, svd_solver=svd_solver
299311
).fit_transform(X)
@@ -302,6 +314,7 @@ def test_incremental_pca_against_pca_iris(svd_solver):
302314

303315

304316
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
317+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
305318
def test_incremental_pca_against_pca_random_data(svd_solver):
306319
# Test that IncrementalPCA and PCA are approximate (to a sign flip).
307320
rng = np.random.RandomState(1999)
@@ -310,7 +323,7 @@ def test_incremental_pca_against_pca_random_data(svd_solver):
310323
X = rng.randn(n_samples, n_features) + 5 * rng.rand(1, n_features)
311324
X = da.from_array(X, chunks=[40, -1])
312325

313-
Y_pca = PCA(n_components=3, svd_solver=svd_solver).fit_transform(X)
326+
Y_pca = PCA(n_components=3, svd_solver=svd_solver).fit_transform(X.compute())
314327
Y_ipca = IncrementalPCA(
315328
n_components=3, batch_size=25, svd_solver=svd_solver
316329
).fit_transform(X)
@@ -319,6 +332,7 @@ def test_incremental_pca_against_pca_random_data(svd_solver):
319332

320333

321334
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
335+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
322336
def test_explained_variances(svd_solver):
323337
# Test that PCA and IncrementalPCA calculations match
324338
X = datasets.make_low_rank_matrix(
@@ -328,7 +342,7 @@ def test_explained_variances(svd_solver):
328342
prec = 3
329343
n_samples, n_features = X.shape
330344
for nc in [None, 99]:
331-
pca = PCA(n_components=nc, svd_solver=svd_solver).fit(X)
345+
pca = PCA(n_components=nc, svd_solver=svd_solver).fit(X.compute())
332346
ipca = IncrementalPCA(
333347
n_components=nc, batch_size=100, svd_solver=svd_solver
334348
).fit(X)
@@ -342,6 +356,7 @@ def test_explained_variances(svd_solver):
342356

343357

344358
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
359+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
345360
def test_singular_values(svd_solver):
346361
# Check that the IncrementalPCA output has the correct singular values
347362

@@ -354,7 +369,7 @@ def test_singular_values(svd_solver):
354369
)
355370
X = da.from_array(X, chunks=[200, -1])
356371

357-
pca = PCA(n_components=10, svd_solver=svd_solver, random_state=rng).fit(X)
372+
pca = PCA(n_components=10, svd_solver=svd_solver, random_state=rng).fit(X.compute())
358373
ipca = IncrementalPCA(n_components=10, batch_size=100, svd_solver=svd_solver).fit(X)
359374
assert_array_almost_equal(pca.singular_values_, ipca.singular_values_, 2)
360375

@@ -389,7 +404,7 @@ def test_singular_values(svd_solver):
389404
pca = PCA(n_components=3, svd_solver=svd_solver, random_state=rng)
390405
ipca = IncrementalPCA(n_components=3, batch_size=100, svd_solver=svd_solver)
391406

392-
X_pca = pca.fit_transform(X)
407+
X_pca = pca.fit_transform(X.compute())
393408
X_pca /= np.sqrt(np.sum(X_pca ** 2.0, axis=0))
394409
X_pca[:, 0] *= 3.142
395410
X_pca[:, 1] *= 2.718
@@ -403,6 +418,7 @@ def test_singular_values(svd_solver):
403418

404419

405420
@pytest.mark.parametrize("svd_solver", ["full", "auto", "randomized"])
421+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
406422
def test_whitening(svd_solver):
407423
# Test that PCA and IncrementalPCA transforms match to sign flip.
408424
X = datasets.make_low_rank_matrix(
@@ -412,7 +428,7 @@ def test_whitening(svd_solver):
412428
prec = 3
413429
n_samples, n_features = X.shape
414430
for nc in [None, 9]:
415-
pca = PCA(whiten=True, n_components=nc, svd_solver=svd_solver).fit(X)
431+
pca = PCA(whiten=True, n_components=nc, svd_solver=svd_solver).fit(X.compute())
416432
ipca = IncrementalPCA(
417433
whiten=True, n_components=nc, batch_size=250, svd_solver=svd_solver
418434
).fit(X)
@@ -427,6 +443,7 @@ def test_whitening(svd_solver):
427443
assert_almost_equal(Xinv_pca, Xinv_ipca, decimal=prec)
428444

429445

446+
@pytest.mark.filterwarnings("ignore:invalid value:RuntimeWarning")
430447
def test_incremental_pca_partial_fit_float_division():
431448
# Test to ensure float division is used in all versions of Python
432449
# (non-regression test for issue #9489)

tests/test_parallel_post_fit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ def test_transform(kind):
9393
base = PCA(random_state=0)
9494
wrap = ParallelPostFit(PCA(random_state=0))
9595

96-
base.fit(X, y)
97-
wrap.fit(X, y)
96+
base.fit(*dask.compute(X, y))
97+
wrap.fit(*dask.compute(X, y))
9898

9999
assert_estimator_equal(wrap.estimator, base)
100100

101-
result = base.transform(X)
101+
result = base.transform(*dask.compute(X))
102102
expected = wrap.transform(X)
103103
assert_eq_ar(result, expected)
104104

0 commit comments

Comments
 (0)