@@ -47,7 +47,7 @@ def ingest(
4747    storage_version : str  =  STORAGE_VERSION ,
4848    verbose : bool  =  False ,
4949    trace_id : Optional [str ] =  None ,
50-     use_sklearn : bool  =  False ,
50+     use_sklearn : bool  =  True ,
5151    mode : Mode  =  Mode .LOCAL ,
5252    ** kwargs ,
5353):
@@ -129,7 +129,7 @@ def ingest(
129129        trace ID for logging, defaults to None 
130130    use_sklearn: bool 
131131        Whether to use scikit-learn's implementation of k-means clustering instead of 
132-         tiledb.vector_search's. Defaults to false . 
132+         tiledb.vector_search's. Defaults to true . 
133133    mode: Mode 
134134        execution mode, defaults to LOCAL use BATCH for distributed execution 
135135    """ 
@@ -933,14 +933,11 @@ def centralised_kmeans(
933933        config : Optional [Mapping [str , Any ]] =  None ,
934934        verbose : bool  =  False ,
935935        trace_id : Optional [str ] =  None ,
936-         use_sklearn : bool  =  False 
936+         use_sklearn : bool  =  True 
937937    ):
938938        from  sklearn .cluster  import  KMeans 
939939
940-         from  tiledb .vector_search .module  import  (
941-             array_to_matrix ,
942-             kmeans_fit ,
943-         )
940+         from  tiledb .vector_search .module  import  array_to_matrix 
944941
945942        with  tiledb .scope_ctx (ctx_or_config = config ):
946943            logger  =  setup (config , verbose )
@@ -990,6 +987,7 @@ def centralised_kmeans(
990987                    km .fit_predict (sample_vectors )
991988                    centroids  =  np .transpose (np .array (km .cluster_centers_ ))
992989                else :
990+                     from  tiledb .vector_search .module  import  kmeans_fit 
993991                    centroids  =  kmeans_fit (partitions , init , max_iter , verbose , n_init , array_to_matrix (np .transpose (sample_vectors )))
994992                    centroids  =  np .array (centroids ) # TODO: why is this here? 
995993            else :
@@ -1044,7 +1042,7 @@ def assign_points_and_partial_new_centroids(
10441042        config : Optional [Mapping [str , Any ]] =  None ,
10451043        verbose : bool  =  False ,
10461044        trace_id : Optional [str ] =  None ,
1047-         use_sklearn : bool  =  False ,
1045+         use_sklearn : bool  =  True ,
10481046    ):
10491047        import  tiledb .cloud 
10501048        from  sklearn .cluster  import  KMeans 
@@ -1145,7 +1143,7 @@ def update_centroids():
11451143            logger .debug ("Assigning vectors to centroids" )
11461144            if  use_sklearn :
11471145                km  =  KMeans ()
1148-                 km .n_threads_  =  threads 
1146+                 km ._n_threads  =  threads 
11491147                km .cluster_centers_  =  centroids 
11501148                assignments  =  km .predict (vectors )
11511149            else :
@@ -1692,7 +1690,7 @@ def create_ingestion_dag(
16921690        config : Optional [Mapping [str , Any ]] =  None ,
16931691        verbose : bool  =  False ,
16941692        trace_id : Optional [str ] =  None ,
1695-         use_sklearn : bool  =  False ,
1693+         use_sklearn : bool  =  True ,
16961694        mode : Mode  =  Mode .LOCAL ,
16971695    ) ->  dag .DAG :
16981696        if  mode  ==  Mode .BATCH :
0 commit comments