diff --git a/Python/phate/sgd_mds.py b/Python/phate/sgd_mds.py
index 63f3e58..a2cfaa3 100644
--- a/Python/phate/sgd_mds.py
+++ b/Python/phate/sgd_mds.py
@@ -25,17 +25,33 @@ def sgd_mds(
     Randomly samples pairs at each iteration - simple and effective!
     This approach is 7-10x faster than SMACOF while maintaining excellent quality.
 
+    Uses an exponential learning rate schedule inspired by the s_gd2 paper,
+    with automatic scaling to account for minibatch sampling.
+
     Parameters
     ----------
-    D : distance matrix [n, n]
-    n_components : output dimensions
-    learning_rate : initial learning rate
-    n_iter : number of iterations
-    init : initial embedding (from classic MDS)
-    random_state : random state
-    verbose : verbosity level
-    pairs_per_iter : number of pairs to sample per iteration
-        If None, uses n * log(n) pairs per iteration
+    D : array-like, shape (n_samples, n_samples)
+        Distance matrix
+    n_components : int, default=2
+        Number of dimensions for the output embedding
+    learning_rate : float, default=0.001
+        Base learning rate (will be scaled automatically for minibatches)
+    n_iter : int, default=500
+        Maximum number of iterations
+    init : array-like, shape (n_samples, n_components), optional
+        Initial embedding (e.g., from classical MDS)
+    random_state : int or RandomState, optional
+        Random state for reproducibility
+    verbose : int, default=0
+        Verbosity level (0=silent, 1=progress, 2=debug)
+    pairs_per_iter : int, optional
+        Number of pairs to sample per iteration.
+        If None, uses n * log(n) pairs per iteration.
+
+    Returns
+    -------
+    Y : array-like, shape (n_samples, n_components)
+        Embedded coordinates
     """
     if random_state is None:
         rng = np.random.RandomState()
@@ -71,10 +87,40 @@ def sgd_mds(
     if verbose > 0:
         _logger.log_debug(f"SGD-MDS: sampling {pairs_per_iter} pairs per iteration")
 
+    # Exponential learning rate schedule (inspired by s_gd2 paper)
+    # Reference: https://github.com/jxz12/s_gd2
+    total_pairs = n_samples * (n_samples - 1) / 2
+    sampling_ratio = pairs_per_iter / total_pairs
+
+    # s_gd2 uses: eta_max = 1/min(w), eta_min = eps/max(w) for weighted MDS
+    # For uniform weights: eta_max = 1, eta_min = eps (eps=0.01)
+    #
+    # Since we're doing minibatch SGD (sampling a fraction of pairs), we compensate
+    # by scaling the learning rate by sqrt(1/sampling_ratio). This balances:
+    # - Higher gradient variance from smaller batches (suggests smaller LR)
+    # - More iterations with partial gradients (allows larger LR)
+    # The sqrt scaling is a standard heuristic from SGD theory.
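+    # For intuition (illustrative figures, not computed anywhere in this function):
+    # with n_samples = 5000 and the default pairs_per_iter = n * log(n) ~= 42,600,
+    # total_pairs ~= 12.5M, so sampling_ratio ~= 0.0034 and
+    # batch_scale = sqrt(1 / 0.0034) ~= 17, i.e. the base rate is boosted about 17x.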
+    batch_scale = np.sqrt(1.0 / sampling_ratio)
+    eta_max = learning_rate * batch_scale
+    eta_min = learning_rate * 0.01 * batch_scale  # eps = 0.01 as in s_gd2
+    lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)
+
+    if verbose > 1:
+        _logger.log_debug(
+            f"SGD-MDS setup: n_samples={n_samples}, pairs_per_iter={pairs_per_iter}, "
+            f"sampling_ratio={sampling_ratio:.6f}, batch_scale={batch_scale:.2f}"
+        )
+        _logger.log_debug(
+            f"Learning rate schedule: eta_max={eta_max:.6f}, eta_min={eta_min:.6f}, "
+            f"lambda={lambd:.6f}, n_iter={n_iter}"
+        )
+
+    prev_stress = None
+    stress_history = []
+
     for iteration in range(n_iter):
-        # Learning rate decay
-        progress = iteration / max(n_iter - 1, 1)
-        lr = learning_rate * (1 - progress) ** 0.8
+        # Exponential decay schedule (s_gd2 style)
+        lr = eta_max * np.exp(-lambd * iteration)
 
         # Randomly sample pairs (without replacement for efficiency)
         # Sample from upper triangle to avoid double-counting
@@ -112,10 +158,44 @@ def sgd_mds(
         # Update
         Y = Y - lr * gradients
 
+        # Compute stress for convergence checking
+        stress = np.sum(errors ** 2) / len(errors)  # Normalized by the number of sampled pairs
+        stress_history.append(stress)
+
         if verbose > 0 and iteration % 100 == 0:
-            stress = np.sum(errors ** 2)
             _logger.log_debug(f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6f}")
 
+        if verbose > 1 and (iteration % 100 == 0 or iteration < 5):
+            _logger.log_debug(
+                f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6e}, "
+                f"mean(|grad|)={np.mean(np.abs(gradients)):.6e}, "
+                f"mean(|Y|)={np.mean(np.abs(Y)):.6e}"
+            )
+
+        # Check for convergence (relative change in stress)
+        if iteration > 0:
+            rel_change = abs(stress - prev_stress) / (prev_stress + 1e-10)
+            if rel_change < 1e-6 and iteration > 50:
+                if verbose > 0:
+                    _logger.log_info(
+                        f"Converged at iteration {iteration} (rel_change={rel_change:.2e})"
+                    )
+                break
+        prev_stress = stress
+
+    # Check if converged properly
+    if len(stress_history) > 10:
+        # Check whether stress was still changing significantly in the last 10% of iterations
+        last_10pct = max(1, len(stress_history) // 10)
+        recent_stress = stress_history[-last_10pct:]
+        if len(recent_stress) > 1:
+            stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
+            if abs(stress_trend) > 0.01:  # Still changing by more than 1%
+                _logger.log_warning(
+                    f"SGD-MDS may not have converged: stress changed by {stress_trend*100:.1f}% "
+                    f"in final iterations. Consider increasing n_iter or adjusting learning_rate."
+                )
+
     # Rescale back to original
     if D_max > 0:
         Y = Y * D_max
@@ -130,10 +210,32 @@ def sgd_mds_metric(
     random_state=None,
     verbose=0,
 ):
-    """Auto-tuned SGD-MDS with optimal parameters for different data sizes"""
+    """Auto-tuned SGD-MDS with optimal parameters for different data sizes
+
+    Automatically selects the number of iterations and pairs per iteration
+    based on the dataset size.
+
+    Parameters
+    ----------
+    D : array-like, shape (n_samples, n_samples)
+        Distance matrix
+    n_components : int, default=2
+        Number of dimensions for the output embedding
+    init : array-like, shape (n_samples, n_components), optional
+        Initial embedding (e.g., from classical MDS)
+    random_state : int or RandomState, optional
+        Random state for reproducibility
+    verbose : int, default=0
+        Verbosity level (0=silent, 1=progress, 2=debug)
+
+    Returns
+    -------
+    Y : array-like, shape (n_samples, n_components)
+        Embedded coordinates
+    """
     n_samples = D.shape[0]
 
-    # Auto-tune: more iterations for larger n
+    # Auto-tune: more iterations for larger datasets
    if n_samples < 1000:
        n_iter = 300
        pairs_per_iter = n_samples * n_samples // 10  # 10% of all pairs
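For reference, the schedule introduced in the second hunk can be sanity-checked outside the patch. The sketch below is illustrative only: the 5000-sample size and the n * log(n) pair count are assumptions, not values taken from the diff. It reproduces the eta_max / eta_min / lambd computation shown above and confirms that the per-iteration rate decays exponentially from eta_max at the first iteration down to eta_min at the last.

import numpy as np

# Illustrative setup (assumed, not taken from the patch): 5000 samples, defaults otherwise
n_samples, n_iter, learning_rate = 5000, 500, 0.001
pairs_per_iter = int(n_samples * np.log(n_samples))  # the n * log(n) default
total_pairs = n_samples * (n_samples - 1) / 2
sampling_ratio = pairs_per_iter / total_pairs

batch_scale = np.sqrt(1.0 / sampling_ratio)           # ~17 for these sizes
eta_max = learning_rate * batch_scale
eta_min = learning_rate * 0.01 * batch_scale          # eps = 0.01 as in s_gd2
lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)

# Per-iteration learning rates under the exponential schedule
lr = eta_max * np.exp(-lambd * np.arange(n_iter))
assert np.isclose(lr[0], eta_max) and np.isclose(lr[-1], eta_min)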