Skip to content

Commit 3d8d1fd

Browse files
Merge pull request #194 from TileDB-Inc/npapa/fix_kmeans
Use sklearn as default for kmeans until the C++ kmeans library is released in Docker images
2 parents 4981e25 + e9e2852 commit 3d8d1fd

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def ingest(
4747
storage_version: str = STORAGE_VERSION,
4848
verbose: bool = False,
4949
trace_id: Optional[str] = None,
50-
use_sklearn: bool = False,
50+
use_sklearn: bool = True,
5151
mode: Mode = Mode.LOCAL,
5252
**kwargs,
5353
):
@@ -129,7 +129,7 @@ def ingest(
129129
trace ID for logging, defaults to None
130130
use_sklearn: bool
131131
Whether to use scikit-learn's implementation of k-means clustering instead of
132-
tiledb.vector_search's. Defaults to false.
132+
tiledb.vector_search's. Defaults to true.
133133
mode: Mode
134134
execution mode, defaults to LOCAL use BATCH for distributed execution
135135
"""
@@ -933,14 +933,11 @@ def centralised_kmeans(
933933
config: Optional[Mapping[str, Any]] = None,
934934
verbose: bool = False,
935935
trace_id: Optional[str] = None,
936-
use_sklearn: bool = False
936+
use_sklearn: bool = True
937937
):
938938
from sklearn.cluster import KMeans
939939

940-
from tiledb.vector_search.module import (
941-
array_to_matrix,
942-
kmeans_fit,
943-
)
940+
from tiledb.vector_search.module import array_to_matrix
944941

945942
with tiledb.scope_ctx(ctx_or_config=config):
946943
logger = setup(config, verbose)
@@ -990,6 +987,7 @@ def centralised_kmeans(
990987
km.fit_predict(sample_vectors)
991988
centroids = np.transpose(np.array(km.cluster_centers_))
992989
else:
990+
from tiledb.vector_search.module import kmeans_fit
993991
centroids = kmeans_fit(partitions, init, max_iter, verbose, n_init, array_to_matrix(np.transpose(sample_vectors)))
994992
centroids = np.array(centroids) # TODO: why is this here?
995993
else:
@@ -1044,7 +1042,7 @@ def assign_points_and_partial_new_centroids(
10441042
config: Optional[Mapping[str, Any]] = None,
10451043
verbose: bool = False,
10461044
trace_id: Optional[str] = None,
1047-
use_sklearn: bool = False,
1045+
use_sklearn: bool = True,
10481046
):
10491047
import tiledb.cloud
10501048
from sklearn.cluster import KMeans
@@ -1145,7 +1143,7 @@ def update_centroids():
11451143
logger.debug("Assigning vectors to centroids")
11461144
if use_sklearn:
11471145
km = KMeans()
1148-
km.n_threads_ = threads
1146+
km._n_threads = threads
11491147
km.cluster_centers_ = centroids
11501148
assignments = km.predict(vectors)
11511149
else:
@@ -1692,7 +1690,7 @@ def create_ingestion_dag(
16921690
config: Optional[Mapping[str, Any]] = None,
16931691
verbose: bool = False,
16941692
trace_id: Optional[str] = None,
1695-
use_sklearn: bool = False,
1693+
use_sklearn: bool = True,
16961694
mode: Mode = Mode.LOCAL,
16971695
) -> dag.DAG:
16981696
if mode == Mode.BATCH:

apis/python/test/test_cloud.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -83,10 +83,7 @@ def test_cloud_ivf_flat(self):
8383
partitions=partitions,
8484
input_vectors_per_work_item=5000,
8585
config=tiledb.cloud.Config().dict(),
86-
# TODO Re-enable.
87-
# This is temporarily disabled due to an incompatibility of new ingestion code and previous
88-
# UDF library releases.
89-
# mode=Mode.BATCH,
86+
mode=Mode.BATCH,
9087
)
9188

9289
tiledb_index_uri = groups.info(index_uri).tiledb_uri

0 commit comments

Comments
 (0)