132 changes: 117 additions & 15 deletions Python/phate/sgd_mds.py
@@ -25,17 +25,33 @@ def sgd_mds(
Randomly samples pairs at each iteration - simple and effective!
This approach is 7-10x faster than SMACOF while maintaining excellent quality.

Uses an exponential learning rate schedule inspired by the s_gd2 paper,
with automatic scaling to account for minibatch sampling.

Parameters
----------
D : array-like, shape (n_samples, n_samples)
Distance matrix
n_components : int, default=2
Number of dimensions for the output embedding
learning_rate : float, default=0.001
Base learning rate (will be scaled automatically for minibatches)
n_iter : int, default=500
Maximum number of iterations
init : array-like, shape (n_samples, n_components), optional
Initial embedding (e.g., from classical MDS)
random_state : int or RandomState, optional
Random state for reproducibility
verbose : int, default=0
Verbosity level (0=silent, 1=progress, 2=debug)
pairs_per_iter : int, optional
Number of pairs to sample per iteration.
If None, uses n * log(n) pairs per iteration.

Returns
-------
Y : array-like, shape (n_samples, n_components)
Embedded coordinates
"""
if random_state is None:
rng = np.random.RandomState()
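Between this hunk and the next, the diff elides the remaining setup code. As a hedged sketch only (the elided lines in the PR may differ), the documented n * log(n) default and the upper-triangle pair sampling referenced in the loop below could look roughly like this, with rng, n_samples, and pairs_per_iter taken from the surrounding context:

import numpy as np

# Illustrative sketch, not part of this diff.
if pairs_per_iter is None:
    pairs_per_iter = int(n_samples * np.log(n_samples))

# Simplified (with-replacement) sampling of index pairs from the upper triangle;
# the PR comments below describe sampling without replacement.
i_idx = rng.randint(0, n_samples, size=pairs_per_iter)
j_idx = rng.randint(0, n_samples, size=pairs_per_iter)
lo, hi = np.minimum(i_idx, j_idx), np.maximum(i_idx, j_idx)
keep = lo != hi                      # drop self-pairs
lo, hi = lo[keep], hi[keep]          # now lo < hi indexes the upper triangle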
@@ -71,10 +87,40 @@ def sgd_mds(
if verbose > 0:
_logger.log_debug(f"SGD-MDS: sampling {pairs_per_iter} pairs per iteration")

# Exponential learning rate schedule (inspired by s_gd2 paper)
# Reference: https://github.com/jxz12/s_gd2
total_pairs = n_samples * (n_samples - 1) / 2
sampling_ratio = pairs_per_iter / total_pairs

# s_gd2 uses: eta_max = 1/min(w), eta_min = eps/max(w) for weighted MDS
# For uniform weights: eta_max = 1, eta_min = eps (eps=0.01)
#
# Since we're doing minibatch SGD (sampling a fraction of pairs), we compensate
# by scaling the learning rate by sqrt(1/sampling_ratio). This balances:
# - Higher gradient variance from smaller batches (suggests smaller LR)
# - More iterations with partial gradients (allows larger LR)
# The sqrt scaling is a standard heuristic from SGD theory.
batch_scale = np.sqrt(1.0 / sampling_ratio)
eta_max = learning_rate * batch_scale
eta_min = learning_rate * 0.01 * batch_scale # eps = 0.01 as in s_gd2
lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)
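# With this lambda, the schedule lr(t) = eta_max * exp(-lambda * t) interpolates
# exactly between the endpoints: lr(0) = eta_max and lr(n_iter - 1) = eta_min.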

if verbose > 1:
_logger.log_debug(
f"SGD-MDS setup: n_samples={n_samples}, pairs_per_iter={pairs_per_iter}, "
f"sampling_ratio={sampling_ratio:.6f}, batch_scale={batch_scale:.2f}"
)
_logger.log_debug(
f"Learning rate schedule: eta_max={eta_max:.6f}, eta_min={eta_min:.6f}, "
f"lambda={lambd:.6f}, n_iter={n_iter}"
)

prev_stress = None
stress_history = []

for iteration in range(n_iter):
# Exponential decay schedule (s_gd2 style)
lr = eta_max * np.exp(-lambd * iteration)

# Randomly sample pairs (without replacement for efficiency)
# Sample from upper triangle to avoid double-counting
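The sampling and gradient computation themselves are elided from this view. As a hedged sketch of the standard metric-MDS stress gradient for the sampled pairs (the names errors and gradients are chosen to match the update and logging below, and lo, hi are the sampled upper-triangle indices from the earlier sketch; the actual PR code may differ):

# Illustrative sketch, not part of this diff.
diff = Y[lo] - Y[hi]                              # (n_pairs, n_components)
dist = np.linalg.norm(diff, axis=1) + 1e-10       # current embedded distances
errors = dist - D[lo, hi]                         # residuals vs. target distances

# Gradient of 0.5 * sum(errors**2) w.r.t. Y, scattered back to the sampled points
grad_pairs = (errors / dist)[:, None] * diff
gradients = np.zeros_like(Y)
np.add.at(gradients, lo, grad_pairs)
np.add.at(gradients, hi, -grad_pairs)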
@@ -112,10 +158,44 @@
# Update
Y = Y - lr * gradients

# Compute stress for convergence checking
stress = np.sum(errors ** 2) / len(errors)  # Normalized by the number of sampled pairs
stress_history.append(stress)


if verbose > 1 and (iteration % 100 == 0 or iteration < 5):
_logger.log_debug(
f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6e}, "
f"mean(|grad|)={np.mean(np.abs(gradients)):.6e}, "
f"mean(|Y|)={np.mean(np.abs(Y)):.6e}"
)

# Check for convergence (relative change in stress)
if iteration > 0:
rel_change = abs(stress - prev_stress) / (prev_stress + 1e-10)
if rel_change < 1e-6 and iteration > 50:
if verbose > 0:
_logger.log_info(
f"Converged at iteration {iteration} (rel_change={rel_change:.2e})"
)
break
prev_stress = stress

# Check if converged properly
if len(stress_history) > 10:
# Check whether stress was still changing significantly over the last 10% of iterations
last_10pct = max(1, len(stress_history) // 10)
recent_stress = stress_history[-last_10pct:]
if len(recent_stress) > 1:
stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
if abs(stress_trend) > 0.01: # Still changing by more than 1%
_logger.log_warning(
f"SGD-MDS may not have converged: stress changed by {stress_trend*100:.1f}% "
f"in final iterations. Consider increasing n_iter or adjusting learning_rate."
)

# Rescale back to original
if D_max > 0:
Y = Y * D_max
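For reference, a minimal usage sketch, not part of the diff; the import path phate.sgd_mds and the scipy helpers are assumptions, and the keyword arguments follow the docstring above:

import numpy as np
from scipy.spatial.distance import pdist, squareform
from phate.sgd_mds import sgd_mds  # assumed import path

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 50))           # toy data
D = squareform(pdist(X))                 # (200, 200) distance matrix

Y = sgd_mds(D, n_components=2, learning_rate=0.001,
            n_iter=500, random_state=42, verbose=1)
print(Y.shape)                           # (200, 2)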
@@ -130,10 +210,32 @@ def sgd_mds_metric(
random_state=None,
verbose=0,
):
"""Auto-tuned SGD-MDS with optimal parameters for different data sizes"""
"""Auto-tuned SGD-MDS with optimal parameters for different data sizes

Automatically selects the number of iterations and pairs per iteration
based on the dataset size.

Parameters
----------
D : array-like, shape (n_samples, n_samples)
Distance matrix
n_components : int, default=2
Number of dimensions for the output embedding
init : array-like, shape (n_samples, n_components), optional
Initial embedding (e.g., from classical MDS)
random_state : int or RandomState, optional
Random state for reproducibility
verbose : int, default=0
Verbosity level (0=silent, 1=progress, 2=debug)

Returns
-------
Y : array-like, shape (n_samples, n_components)
Embedded coordinates
"""
n_samples = D.shape[0]

# Auto-tune: more iterations for larger datasets
if n_samples < 1000:
n_iter = 300
pairs_per_iter = n_samples * n_samples // 10  # n**2 // 10 pairs (~20% of the unique pairs)
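# Illustrative arithmetic (not in the diff): for n_samples = 800 this branch gives
# pairs_per_iter = 800 * 800 // 10 = 64,000 sampled index pairs per iteration,
# against 800 * 799 / 2 = 319,600 unique pairs, i.e. a sampling ratio of ~0.20.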