@@ -265,7 +265,10 @@ def optimize(self, n_iter, inplace=False, propagate_exception=False, **kwargs):
265265 new_embedding = self .embedding_ .optimize (** kwargs )
266266 table = Table (self .embedding .domain , new_embedding .view (np .ndarray ),
267267 self .embedding .Y , self .embedding .metas )
268- return TSNEModel (new_embedding , table , self .pre_domain )
268+
269+ new_model = TSNEModel (new_embedding , table , self .pre_domain )
270+ new_model .name = self .name
271+ return new_model
269272
270273
271274class TSNE (Projector ):
@@ -400,7 +403,7 @@ def __init__(self, n_components=2, perplexity=30, learning_rate=200,
400403 self .callbacks_every_iters = callbacks_every_iters
401404 self .random_state = random_state
402405
403- def fit (self , X : np . ndarray , Y : np . ndarray = None ) -> openTSNE . TSNEEmbedding :
406+ def compute_affinities (self , X ) :
404407 # Sparse data are not supported
405408 if sp .issparse (X ):
406409 raise TypeError (
@@ -415,41 +418,75 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
415418 if not isinstance (self .perplexity , Iterable ):
416419 raise ValueError (
417420 "Perplexity should be an instance of `Iterable`, `%s` "
418- "given." % type (self .perplexity ).__name__ )
421+ "given." % type (self .perplexity ).__name__
422+ )
419423 affinities = openTSNE .affinity .Multiscale (
420- X , perplexities = self .perplexity , metric = self .metric ,
421- method = self .neighbors , random_state = self .random_state , n_jobs = self .n_jobs )
424+ X ,
425+ perplexities = self .perplexity ,
426+ metric = self .metric ,
427+ method = self .neighbors ,
428+ random_state = self .random_state ,
429+ n_jobs = self .n_jobs ,
430+ )
422431 else :
423432 if isinstance (self .perplexity , Iterable ):
424433 raise ValueError (
425434 "Perplexity should be an instance of `float`, `%s` "
426- "given." % type (self .perplexity ).__name__ )
435+ "given." % type (self .perplexity ).__name__
436+ )
427437 affinities = openTSNE .affinity .PerplexityBasedNN (
428- X , perplexity = self .perplexity , metric = self .metric ,
429- method = self .neighbors , random_state = self .random_state , n_jobs = self .n_jobs )
438+ X ,
439+ perplexity = self .perplexity ,
440+ metric = self .metric ,
441+ method = self .neighbors ,
442+ random_state = self .random_state ,
443+ n_jobs = self .n_jobs ,
444+ )
430445
431- # Create an initial embedding
446+ return affinities
447+
448+ def compute_initialization (self , X ):
449+ # Compute the initial positions of individual points
432450 if isinstance (self .initialization , np .ndarray ):
433451 initialization = self .initialization
434452 elif self .initialization == "pca" :
435453 initialization = openTSNE .initialization .pca (
436- X , self .n_components , random_state = self .random_state )
454+ X , self .n_components , random_state = self .random_state
455+ )
437456 elif self .initialization == "random" :
438457 initialization = openTSNE .initialization .random (
439- X , self .n_components , random_state = self .random_state )
458+ X , self .n_components , random_state = self .random_state
459+ )
440460 else :
441461 raise ValueError (
442462 "Invalid initialization `%s`. Please use either `pca` or "
443- "`random` or provide a numpy array." % self .initialization )
463+ "`random` or provide a numpy array." % self .initialization
464+ )
444465
445- embedding = openTSNE .TSNEEmbedding (
446- initialization , affinities , learning_rate = self .learning_rate ,
447- theta = self .theta , min_num_intervals = self .min_num_intervals ,
448- ints_in_interval = self .ints_in_interval , n_jobs = self .n_jobs ,
466+ return initialization
467+
468+ def prepare_embedding (self , affinities , initialization ):
469+ """Prepare an embedding object with appropriate parameters, given some
470+ affinities and initialization."""
471+ return openTSNE .TSNEEmbedding (
472+ initialization ,
473+ affinities ,
474+ learning_rate = self .learning_rate ,
475+ theta = self .theta ,
476+ min_num_intervals = self .min_num_intervals ,
477+ ints_in_interval = self .ints_in_interval ,
478+ n_jobs = self .n_jobs ,
449479 negative_gradient_method = self .negative_gradient_method ,
450- callbacks = self .callbacks , callbacks_every_iters = self .callbacks_every_iters ,
480+ callbacks = self .callbacks ,
481+ callbacks_every_iters = self .callbacks_every_iters ,
451482 )
452483
484+ def fit (self , X : np .ndarray , Y : np .ndarray = None ) -> openTSNE .TSNEEmbedding :
485+ # Compute affinities and initial positions and prepare the embedding object
486+ affinities = self .compute_affinities (X )
487+ initialization = self .compute_initialization (X )
488+ embedding = self .prepare_embedding (affinities , initialization )
489+
453490 # Run standard t-SNE optimization
454491 embedding .optimize (
455492 n_iter = self .early_exaggeration_iter , exaggeration = self .early_exaggeration ,
@@ -462,13 +499,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
462499
463500 return embedding
464501
465- def __call__ (self , data : Table ) -> TSNEModel :
466- # Preprocess the data - convert discrete to continuous
467- data = self .preprocess (data )
468-
469- # Run tSNE optimization
470- embedding = self .fit (data .X , data .Y )
471-
502+ def convert_embedding_to_model (self , data , embedding ):
472503 # The results should be accessible in an Orange table, which doesn't
473504 # need the full embedding attributes and is cast into a regular array
474505 n = self .n_components
@@ -484,6 +515,17 @@ def __call__(self, data: Table) -> TSNEModel:
484515
485516 return model
486517
518+ def __call__ (self , data : Table ) -> TSNEModel :
519+ # Preprocess the data - convert discrete to continuous
520+ data = self .preprocess (data )
521+
522+ # Run tSNE optimization
523+ embedding = self .fit (data .X , data .Y )
524+
525+ # Convert the t-SNE embedding object to a TSNEModel and prepare the
526+ # embedding table with t-SNE meta variables
527+ return self .convert_embedding_to_model (data , embedding )
528+
487529 @staticmethod
488530 def default_initialization (data , n_components = 2 , random_state = None ):
489531 return openTSNE .initialization .pca (
0 commit comments