Skip to content

Commit 1b9c79c

Browse files
authored
Modify _reduce_dimensionality to use fit_transform for more cuML support (#2416)
1 parent c6157fe commit 1b9c79c

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

bertopic/_bertopic.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3930,19 +3930,33 @@ def _reduce_dimensionality(
39303930
if partial_fit:
39313931
if hasattr(self.umap_model, "partial_fit"):
39323932
self.umap_model = self.umap_model.partial_fit(embeddings)
3933+
umap_embeddings = self.umap_model.transform(embeddings)
39333934
elif self.topic_representations_ is None:
3934-
self.umap_model.fit(embeddings)
3935+
if hasattr(self.umap_model, "fit_transform"):
3936+
umap_embeddings = self.umap_model.fit_transform(embeddings)
3937+
else:
3938+
self.umap_model.fit(embeddings)
3939+
umap_embeddings = self.umap_model.transform(embeddings)
3940+
else:
3941+
umap_embeddings = self.umap_model.transform(embeddings)
39353942

39363943
# Regular fit
39373944
else:
39383945
try:
39393946
# cuml umap needs y to be an numpy array
39403947
y = np.array(y) if y is not None else None
3941-
self.umap_model.fit(embeddings, y=y)
3948+
if hasattr(self.umap_model, "fit_transform"):
3949+
umap_embeddings = self.umap_model.fit_transform(embeddings, y=y)
3950+
else:
3951+
self.umap_model.fit(embeddings, y=y)
3952+
umap_embeddings = self.umap_model.transform(embeddings)
39423953
except TypeError:
3943-
self.umap_model.fit(embeddings)
3954+
if hasattr(self.umap_model, "fit_transform"):
3955+
umap_embeddings = self.umap_model.fit_transform(embeddings, y=y)
3956+
else:
3957+
self.umap_model.fit(embeddings, y=y)
3958+
umap_embeddings = self.umap_model.transform(embeddings)
39443959

3945-
umap_embeddings = self.umap_model.transform(embeddings)
39463960
logger.info("Dimensionality - Completed \u2713")
39473961
return np.nan_to_num(umap_embeddings)
39483962

bertopic/dimensionality/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class BaseDimensionalityReduction:
1919
```
2020
"""
2121

22-
def fit(self, X: np.ndarray = None):
22+
def fit(self, X: np.ndarray = None, y: np.ndarray = None):
2323
return self
2424

2525
def transform(self, X: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)