Skip to content

Commit 468b329

Browse files
committed
introduce variable preprocessed to make sure scaling and NaN-handling don't interfere
1 parent 13840f0 commit 468b329

File tree

1 file changed

+28
-16
lines changed

1 file changed

+28
-16
lines changed

src/napari_clusters_plotter/_algorithms.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ def _reduce_pca(
2020
non_nan_data = data.dropna()
2121

2222
if scale:
23-
non_nan_data = StandardScaler().fit_transform(non_nan_data.values)
23+
preprocessed = StandardScaler().fit_transform(non_nan_data.values)
2424
else:
25-
non_nan_data = non_nan_data.values
25+
preprocessed = non_nan_data.values
2626

2727
pca = PCA(n_components=n_components)
28-
pca.fit(non_nan_data)
29-
reduced_data = pca.transform(non_nan_data)
28+
pca.fit(preprocessed)
29+
reduced_data = pca.transform(preprocessed)
3030

3131
# Add NaN rows back
3232
result = pd.DataFrame(index=data.index, columns=range(n_components))
@@ -58,9 +58,11 @@ def _reduce_tsne(
5858
non_nan_data = data.dropna()
5959

6060
if scale:
61-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
61+
preprocessed = StandardScaler().fit_transform(non_nan_data)
62+
else:
63+
preprocessed = non_nan_data.values
6264
tsne = TSNE(n_components=n_components, perplexity=perplexity)
63-
reduced_data = tsne.fit_transform(non_nan_data)
65+
reduced_data = tsne.fit_transform(preprocessed)
6466

6567
# Add NaN rows back
6668
result = pd.DataFrame(index=data.index, columns=range(n_components))
@@ -92,10 +94,12 @@ def _reduce_umap(
9294
non_nan_data = data.dropna()
9395

9496
if scale:
95-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
97+
preprocessed = StandardScaler().fit_transform(non_nan_data)
98+
else:
99+
preprocessed = non_nan_data.values
96100

97101
reducer = umap.UMAP(n_components=n_components, n_neighbors=n_neighbors)
98-
reduced_data = reducer.fit_transform(non_nan_data)
102+
reduced_data = reducer.fit_transform(preprocessed)
99103

100104
# Add NaN rows back
101105
result = pd.DataFrame(index=data.index, columns=range(n_components))
@@ -124,10 +128,12 @@ def _cluster_kmeans(
124128
non_nan_data = data.dropna()
125129

126130
if scale:
127-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
131+
preprocessed = StandardScaler().fit_transform(non_nan_data)
132+
else:
133+
preprocessed = non_nan_data.values
128134

129135
kmeans = KMeans(n_clusters=n_clusters)
130-
clusters = kmeans.fit_predict(non_nan_data)
136+
clusters = kmeans.fit_predict(preprocessed)
131137

132138
# Add NaN rows back
133139
result = pd.Series(index=data.index, dtype=int)
@@ -162,12 +168,14 @@ def _cluster_hdbscan(
162168
non_nan_data = data.dropna()
163169

164170
if scale:
165-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
171+
preprocessed = StandardScaler().fit_transform(non_nan_data)
172+
else:
173+
preprocessed = non_nan_data.values
166174

167175
clusterer = HDBSCAN(
168176
min_cluster_size=min_cluster_size, min_samples=min_samples
169177
)
170-
clusters = clusterer.fit_predict(non_nan_data)
178+
clusters = clusterer.fit_predict(preprocessed)
171179

172180
# Add NaN rows back
173181
result = pd.Series(index=data.index, dtype=int)
@@ -196,10 +204,12 @@ def _cluster_gaussian_mixture(
196204
non_nan_data = data.dropna()
197205

198206
if scale:
199-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
207+
preprocessed = StandardScaler().fit_transform(non_nan_data)
208+
else:
209+
preprocessed = non_nan_data.values
200210

201211
gmm = GaussianMixture(n_components=n_components)
202-
clusters = gmm.fit_predict(non_nan_data)
212+
clusters = gmm.fit_predict(preprocessed)
203213

204214
# Add NaN rows back
205215
result = pd.Series(index=data.index, dtype=int)
@@ -228,10 +238,12 @@ def _cluster_spectral(
228238
non_nan_data = data.dropna()
229239

230240
if scale:
231-
non_nan_data = StandardScaler().fit_transform(non_nan_data)
241+
preprocessed = StandardScaler().fit_transform(non_nan_data)
242+
else:
243+
preprocessed = non_nan_data.values
232244

233245
clusterer = SpectralClustering(n_clusters=n_clusters)
234-
clusters = clusterer.fit_predict(non_nan_data)
246+
clusters = clusterer.fit_predict(preprocessed)
235247

236248
# Add NaN rows back
237249
result = pd.Series(index=data.index, dtype=int)

0 commit comments

Comments (0)