Commit e920469

Merge pull request #164 from KrishnaswamyLab/mila_local_phate
changed lr schedule for SGD MDS
2 parents: 87f9b60 + e72c4c0

1 file changed: +117 -15

Python/phate/sgd_mds.py

Lines changed: 117 additions & 15 deletions
@@ -25,17 +25,33 @@ def sgd_mds(
     Randomly samples pairs at each iteration - simple and effective!
     This approach is 7-10x faster than SMACOF while maintaining excellent quality.
 
+    Uses an exponential learning rate schedule inspired by the s_gd2 paper,
+    with automatic scaling to account for minibatch sampling.
+
     Parameters
     ----------
-    D : distance matrix [n, n]
-    n_components : output dimensions
-    learning_rate : initial learning rate
-    n_iter : number of iterations
-    init : initial embedding (from classic MDS)
-    random_state : random state
-    verbose : verbosity level
-    pairs_per_iter : number of pairs to sample per iteration
-        If None, uses n * log(n) pairs per iteration
+    D : array-like, shape (n_samples, n_samples)
+        Distance matrix
+    n_components : int, default=2
+        Number of dimensions for the output embedding
+    learning_rate : float, default=0.001
+        Base learning rate (will be scaled automatically for minibatches)
+    n_iter : int, default=500
+        Maximum number of iterations
+    init : array-like, shape (n_samples, n_components), optional
+        Initial embedding (e.g., from classical MDS)
+    random_state : int or RandomState, optional
+        Random state for reproducibility
+    verbose : int, default=0
+        Verbosity level (0=silent, 1=progress, 2=debug)
+    pairs_per_iter : int, optional
+        Number of pairs to sample per iteration.
+        If None, uses n * log(n) pairs per iteration.
+
+    Returns
+    -------
+    Y : array-like, shape (n_samples, n_components)
+        Embedded coordinates
     """
     if random_state is None:
         rng = np.random.RandomState()
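For orientation, here is a minimal usage sketch of the function this diff documents. The import path is an assumption inferred from the file location (Python/phate/sgd_mds.py) and may differ; the call signature follows the docstring above.

    import numpy as np
    from scipy.spatial.distance import pdist, squareform

    # Hypothetical import path inferred from the file location; adjust as needed.
    from phate.sgd_mds import sgd_mds

    X = np.random.RandomState(42).normal(size=(200, 10))  # toy data
    D = squareform(pdist(X))                               # (200, 200) distance matrix

    # Defaults per the docstring: learning_rate=0.001, n_iter=500,
    # pairs_per_iter=None (i.e., n * log(n) pairs per iteration).
    Y = sgd_mds(D, n_components=2, random_state=42, verbose=1)
    print(Y.shape)  # (200, 2)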
@@ -71,10 +87,40 @@ def sgd_mds(
     if verbose > 0:
         _logger.log_debug(f"SGD-MDS: sampling {pairs_per_iter} pairs per iteration")
 
+    # Exponential learning rate schedule (inspired by s_gd2 paper)
+    # Reference: https://github.com/jxz12/s_gd2
+    total_pairs = n_samples * (n_samples - 1) / 2
+    sampling_ratio = pairs_per_iter / total_pairs
+
+    # s_gd2 uses: eta_max = 1/min(w), eta_min = eps/max(w) for weighted MDS
+    # For uniform weights: eta_max = 1, eta_min = eps (eps=0.01)
+    #
+    # Since we're doing minibatch SGD (sampling a fraction of pairs), we compensate
+    # by scaling the learning rate by sqrt(1/sampling_ratio). This balances:
+    #   - Higher gradient variance from smaller batches (suggests smaller LR)
+    #   - More iterations with partial gradients (allows larger LR)
+    # The sqrt scaling is a standard heuristic from SGD theory.
+    batch_scale = np.sqrt(1.0 / sampling_ratio)
+    eta_max = learning_rate * batch_scale
+    eta_min = learning_rate * 0.01 * batch_scale  # eps = 0.01 as in s_gd2
+    lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)
+
+    if verbose > 1:
+        _logger.log_debug(
+            f"SGD-MDS setup: n_samples={n_samples}, pairs_per_iter={pairs_per_iter}, "
+            f"sampling_ratio={sampling_ratio:.6f}, batch_scale={batch_scale:.2f}"
+        )
+        _logger.log_debug(
+            f"Learning rate schedule: eta_max={eta_max:.6f}, eta_min={eta_min:.6f}, "
+            f"lambda={lambd:.6f}, n_iter={n_iter}"
+        )
+
+    prev_stress = None
+    stress_history = []
+
     for iteration in range(n_iter):
-        # Learning rate decay
-        progress = iteration / max(n_iter - 1, 1)
-        lr = learning_rate * (1 - progress) ** 0.8
+        # Exponential decay schedule (s_gd2 style)
+        lr = eta_max * np.exp(-lambd * iteration)
 
         # Randomly sample pairs (without replacement for efficiency)
         # Sample from upper triangle to avoid double-counting
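To see what the new schedule does, the standalone sketch below recomputes it outside the function, using only values defined in this diff (the n * log(n) default for pairs_per_iter is taken from the docstring; n_samples here is an arbitrary example). The final learning rate lands on eta_min by construction.

    import numpy as np

    n_samples = 2000
    n_iter = 500
    learning_rate = 0.001
    pairs_per_iter = int(n_samples * np.log(n_samples))  # documented default

    total_pairs = n_samples * (n_samples - 1) / 2
    sampling_ratio = pairs_per_iter / total_pairs        # fraction of pairs seen per step
    batch_scale = np.sqrt(1.0 / sampling_ratio)          # sqrt compensation for minibatching

    eta_max = learning_rate * batch_scale
    eta_min = learning_rate * 0.01 * batch_scale         # eps = 0.01 as in s_gd2
    lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)

    lrs = eta_max * np.exp(-lambd * np.arange(n_iter))   # decays from eta_max to eta_min
    assert np.isclose(lrs[-1], eta_min)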
@@ -112,10 +158,44 @@ def sgd_mds(
         # Update
         Y = Y - lr * gradients
 
+        # Compute stress for convergence checking
+        stress = np.sum(errors ** 2) / len(errors)  # Normalized by number of samples
+        stress_history.append(stress)
+
         if verbose > 0 and iteration % 100 == 0:
-            stress = np.sum(errors ** 2)
             _logger.log_debug(f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6f}")
 
+        if verbose > 1 and (iteration % 100 == 0 or iteration < 5):
+            _logger.log_debug(
+                f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6e}, "
+                f"mean(|grad|)={np.mean(np.abs(gradients)):.6e}, "
+                f"mean(|Y|)={np.mean(np.abs(Y)):.6e}"
+            )
+
+        # Check for convergence (relative change in stress)
+        if iteration > 0:
+            rel_change = abs(stress - prev_stress) / (prev_stress + 1e-10)
+            if rel_change < 1e-6 and iteration > 50:
+                if verbose > 0:
+                    _logger.log_info(
+                        f"Converged at iteration {iteration} (rel_change={rel_change:.2e})"
+                    )
+                break
+        prev_stress = stress
+
+    # Check if converged properly
+    if len(stress_history) > 10:
+        # Check if stress is still decreasing significantly in last 10% of iterations
+        last_10pct = max(1, len(stress_history) // 10)
+        recent_stress = stress_history[-last_10pct:]
+        if len(recent_stress) > 1:
+            stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
+            if abs(stress_trend) > 0.01:  # Still changing by more than 1%
+                _logger.log_warning(
+                    f"SGD-MDS may not have converged: stress changed by {stress_trend*100:.1f}% "
+                    f"in final iterations. Consider increasing n_iter or adjusting learning_rate."
+                )
+
     # Rescale back to original
     if D_max > 0:
         Y = Y * D_max
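The stopping rule added above is a generic relative-change test. The sketch below isolates it on a synthetic, geometrically decaying stress curve (the curve is made up for illustration; in the real code the stress comes from the sampled pairs). With these numbers it stops around iteration 150.

    prev_stress = None
    for iteration in range(500):
        stress = 0.9 ** iteration + 0.01  # stand-in for the real pairwise stress
        if prev_stress is not None:
            rel_change = abs(stress - prev_stress) / (prev_stress + 1e-10)
            if rel_change < 1e-6 and iteration > 50:
                print(f"Converged at iteration {iteration} (rel_change={rel_change:.2e})")
                break
        prev_stress = stress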
@@ -130,10 +210,32 @@ def sgd_mds_metric(
     random_state=None,
     verbose=0,
 ):
-    """Auto-tuned SGD-MDS with optimal parameters for different data sizes"""
+    """Auto-tuned SGD-MDS with optimal parameters for different data sizes
+
+    Automatically selects the number of iterations and pairs per iteration
+    based on the dataset size.
+
+    Parameters
+    ----------
+    D : array-like, shape (n_samples, n_samples)
+        Distance matrix
+    n_components : int, default=2
+        Number of dimensions for the output embedding
+    init : array-like, shape (n_samples, n_components), optional
+        Initial embedding (e.g., from classical MDS)
+    random_state : int or RandomState, optional
+        Random state for reproducibility
+    verbose : int, default=0
+        Verbosity level (0=silent, 1=progress, 2=debug)
+
+    Returns
+    -------
+    Y : array-like, shape (n_samples, n_components)
+        Embedded coordinates
+    """
     n_samples = D.shape[0]
 
-    # Auto-tune: more iterations for larger n
+    # Auto-tune: more iterations for larger datasets
     if n_samples < 1000:
         n_iter = 300
         pairs_per_iter = n_samples * n_samples // 10  # 10% of all pairs
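For a sense of scale, the snippet below compares the two pair budgets that appear in this file: the n * log(n) default documented in sgd_mds, and the n^2 / 10 budget sgd_mds_metric uses for small datasets (the diff is truncated before any larger-n branches, so only this branch is grounded in the source).

    import numpy as np

    for n in (300, 1000, 3000):
        nlogn = int(n * np.log(n))   # sgd_mds default when pairs_per_iter is None
        tenth = n * n // 10          # sgd_mds_metric choice for n_samples < 1000
        print(f"n={n:5d}  n*log(n)={nlogn:9,d}  n^2/10={tenth:9,d}")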
