@@ -400,7 +400,7 @@ def __init__(self, n_components=2, perplexity=30, learning_rate=200,
400400 self .callbacks_every_iters = callbacks_every_iters
401401 self .random_state = random_state
402402
403- def fit (self , X : np . ndarray , Y : np . ndarray = None ) -> openTSNE . TSNEEmbedding :
403+ def compute_affinities (self , X ) :
404404 # Sparse data are not supported
405405 if sp .issparse (X ):
406406 raise TypeError (
@@ -415,41 +415,75 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
415415 if not isinstance (self .perplexity , Iterable ):
416416 raise ValueError (
417417 "Perplexity should be an instance of `Iterable`, `%s` "
418- "given." % type (self .perplexity ).__name__ )
418+ "given." % type (self .perplexity ).__name__
419+ )
419420 affinities = openTSNE .affinity .Multiscale (
420- X , perplexities = self .perplexity , metric = self .metric ,
421- method = self .neighbors , random_state = self .random_state , n_jobs = self .n_jobs )
421+ X ,
422+ perplexities = self .perplexity ,
423+ metric = self .metric ,
424+ method = self .neighbors ,
425+ random_state = self .random_state ,
426+ n_jobs = self .n_jobs ,
427+ )
422428 else :
423429 if isinstance (self .perplexity , Iterable ):
424430 raise ValueError (
425431 "Perplexity should be an instance of `float`, `%s` "
426- "given." % type (self .perplexity ).__name__ )
432+ "given." % type (self .perplexity ).__name__
433+ )
427434 affinities = openTSNE .affinity .PerplexityBasedNN (
428- X , perplexity = self .perplexity , metric = self .metric ,
429- method = self .neighbors , random_state = self .random_state , n_jobs = self .n_jobs )
435+ X ,
436+ perplexity = self .perplexity ,
437+ metric = self .metric ,
438+ method = self .neighbors ,
439+ random_state = self .random_state ,
440+ n_jobs = self .n_jobs ,
441+ )
430442
431- # Create an initial embedding
443+ return affinities
444+
445+ def compute_initialization (self , X ):
446+ # Compute the initial positions of individual points
432447 if isinstance (self .initialization , np .ndarray ):
433448 initialization = self .initialization
434449 elif self .initialization == "pca" :
435450 initialization = openTSNE .initialization .pca (
436- X , self .n_components , random_state = self .random_state )
451+ X , self .n_components , random_state = self .random_state
452+ )
437453 elif self .initialization == "random" :
438454 initialization = openTSNE .initialization .random (
439- X , self .n_components , random_state = self .random_state )
455+ X , self .n_components , random_state = self .random_state
456+ )
440457 else :
441458 raise ValueError (
442459 "Invalid initialization `%s`. Please use either `pca` or "
443- "`random` or provide a numpy array." % self .initialization )
460+ "`random` or provide a numpy array." % self .initialization
461+ )
444462
445- embedding = openTSNE .TSNEEmbedding (
446- initialization , affinities , learning_rate = self .learning_rate ,
447- theta = self .theta , min_num_intervals = self .min_num_intervals ,
448- ints_in_interval = self .ints_in_interval , n_jobs = self .n_jobs ,
463+ return initialization
464+
465+ def prepare_embedding (self , affinities , initialization ):
466+ """Prepare an embedding object with appropriate parameters, given some
467+ affinities and initialization."""
468+ return openTSNE .TSNEEmbedding (
469+ initialization ,
470+ affinities ,
471+ learning_rate = self .learning_rate ,
472+ theta = self .theta ,
473+ min_num_intervals = self .min_num_intervals ,
474+ ints_in_interval = self .ints_in_interval ,
475+ n_jobs = self .n_jobs ,
449476 negative_gradient_method = self .negative_gradient_method ,
450- callbacks = self .callbacks , callbacks_every_iters = self .callbacks_every_iters ,
477+ callbacks = self .callbacks ,
478+ callbacks_every_iters = self .callbacks_every_iters ,
451479 )
452480
481+ def fit (self , X : np .ndarray , Y : np .ndarray = None ) -> openTSNE .TSNEEmbedding :
482+ # Compute affinities and initial positions and prepare the embedding object
483+ affinities = self .compute_affinities (X )
484+ initialization = self .compute_initialization (X )
485+ embedding = self .prepare_embedding (affinities , initialization )
486+
453487 # Run standard t-SNE optimization
454488 embedding .optimize (
455489 n_iter = self .early_exaggeration_iter , exaggeration = self .early_exaggeration ,
0 commit comments