@@ -25,17 +25,33 @@ def sgd_mds(
     Randomly samples pairs at each iteration - simple and effective!
     This approach is 7-10x faster than SMACOF while maintaining excellent quality.

+    Uses an exponential learning rate schedule inspired by the s_gd2 paper,
+    with automatic scaling to account for minibatch sampling.
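+
+    The per-iteration rate follows lr(t) = eta_max * exp(-lambd * t), with
+    lambd chosen so that the rate decays from eta_max to eta_min over the run.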
+
     Parameters
     ----------
-    D : distance matrix [n, n]
-    n_components : output dimensions
-    learning_rate : initial learning rate
-    n_iter : number of iterations
-    init : initial embedding (from classic MDS)
-    random_state : random state
-    verbose : verbosity level
-    pairs_per_iter : number of pairs to sample per iteration
-        If None, uses n * log(n) pairs per iteration
+    D : array-like, shape (n_samples, n_samples)
+        Distance matrix
+    n_components : int, default=2
+        Number of dimensions for the output embedding
+    learning_rate : float, default=0.001
+        Base learning rate (will be scaled automatically for minibatches)
+    n_iter : int, default=500
+        Maximum number of iterations
+    init : array-like, shape (n_samples, n_components), optional
+        Initial embedding (e.g., from classical MDS)
+    random_state : int or RandomState, optional
+        Random state for reproducibility
+    verbose : int, default=0
+        Verbosity level (0=silent, 1=progress, 2=debug)
+    pairs_per_iter : int, optional
+        Number of pairs to sample per iteration.
+        If None, uses n * log(n) pairs per iteration.
+
+    Returns
+    -------
+    Y : array-like, shape (n_samples, n_components)
+        Embedded coordinates
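+
+    Examples
+    --------
+    A minimal sketch (assumes a precomputed, symmetric distance matrix)::
+
+        rng = np.random.RandomState(0)
+        X = rng.rand(100, 5)
+        D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
+        Y = sgd_mds(D, n_components=2, random_state=0)  # shape (100, 2)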
3955 """
4056 if random_state is None :
4157 rng = np .random .RandomState ()
@@ -71,10 +87,40 @@ def sgd_mds(
     if verbose > 0:
         _logger.log_debug(f"SGD-MDS: sampling {pairs_per_iter} pairs per iteration")

+    # Exponential learning rate schedule (inspired by the s_gd2 paper)
+    # Reference: https://github.com/jxz12/s_gd2
+    total_pairs = n_samples * (n_samples - 1) / 2
+    sampling_ratio = pairs_per_iter / total_pairs
+
+    # s_gd2 uses: eta_max = 1/min(w), eta_min = eps/max(w) for weighted MDS.
+    # For uniform weights: eta_max = 1, eta_min = eps (eps=0.01).
+    #
+    # Since we're doing minibatch SGD (sampling a fraction of the pairs), we
+    # compensate by scaling the learning rate by sqrt(1/sampling_ratio). This balances:
+    # - higher gradient variance from smaller batches (suggests a smaller LR)
+    # - more iterations with partial gradients (allows a larger LR)
+    # The sqrt scaling is a standard heuristic from SGD theory.
+    batch_scale = np.sqrt(1.0 / sampling_ratio)
+    eta_max = learning_rate * batch_scale
+    eta_min = learning_rate * 0.01 * batch_scale  # eps = 0.01 as in s_gd2
+    lambd = np.log(eta_max / eta_min) / max(n_iter - 1, 1)
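+    # Worked example (illustrative numbers): with learning_rate=0.001 and
+    # sampling_ratio=1e-4, batch_scale = sqrt(1e4) = 100, so eta_max = 0.1 and
+    # eta_min = 0.001. With n_iter=500, lambd = ln(100)/499 ≈ 0.00923, and lr
+    # decays from 0.1 at iteration 0 down to 0.001 at iteration 499.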
+
+    if verbose > 1:
+        _logger.log_debug(
+            f"SGD-MDS setup: n_samples={n_samples}, pairs_per_iter={pairs_per_iter}, "
+            f"sampling_ratio={sampling_ratio:.6f}, batch_scale={batch_scale:.2f}"
+        )
+        _logger.log_debug(
+            f"Learning rate schedule: eta_max={eta_max:.6f}, eta_min={eta_min:.6f}, "
+            f"lambda={lambd:.6f}, n_iter={n_iter}"
+        )
+
+    prev_stress = None
+    stress_history = []
+
     for iteration in range(n_iter):
-        # Learning rate decay
-        progress = iteration / max(n_iter - 1, 1)
-        lr = learning_rate * (1 - progress) ** 0.8
+        # Exponential decay schedule (s_gd2 style)
+        lr = eta_max * np.exp(-lambd * iteration)
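+        # At iteration 0 this gives eta_max; at iteration n_iter - 1 it reaches
+        # exactly eta_min, since lambd = ln(eta_max / eta_min) / (n_iter - 1).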

         # Randomly sample pairs (without replacement for efficiency)
         # Sample from upper triangle to avoid double-counting
@@ -112,10 +158,44 @@ def sgd_mds(
         # Update
         Y = Y - lr * gradients

+        # Compute stress for convergence checking
+        stress = np.sum(errors ** 2) / len(errors)  # normalized by the number of sampled pairs
+        stress_history.append(stress)
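+        # Note: this is a noisy minibatch estimate of the stress (only the
+        # sampled pairs contribute), not the full sum over all n(n-1)/2 pairs.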
+
         if verbose > 0 and iteration % 100 == 0:
-            stress = np.sum(errors ** 2)
             _logger.log_debug(f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6f}")

+        if verbose > 1 and (iteration % 100 == 0 or iteration < 5):
+            _logger.log_debug(
+                f"Iter {iteration}: stress={stress:.6f}, lr={lr:.6e}, "
+                f"mean(|grad|)={np.mean(np.abs(gradients)):.6e}, "
+                f"mean(|Y|)={np.mean(np.abs(Y)):.6e}"
+            )
+
+        # Check for convergence (relative change in stress)
+        if iteration > 0:
+            rel_change = abs(stress - prev_stress) / (prev_stress + 1e-10)
+            if rel_change < 1e-6 and iteration > 50:
+                if verbose > 0:
+                    _logger.log_info(
+                        f"Converged at iteration {iteration} (rel_change={rel_change:.2e})"
+                    )
+                break
+        prev_stress = stress
+
+    # Post-run sanity check: warn if the optimization was still moving at the end
+    if len(stress_history) > 10:
+        # Check whether stress was still changing significantly in the last 10% of iterations
+        last_10pct = max(1, len(stress_history) // 10)
+        recent_stress = stress_history[-last_10pct:]
+        if len(recent_stress) > 1:
+            stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
+            if abs(stress_trend) > 0.01:  # still changing by more than 1%
+                _logger.log_warning(
+                    f"SGD-MDS may not have converged: stress changed by {stress_trend * 100:.1f}% "
+                    f"in the final iterations. Consider increasing n_iter or adjusting learning_rate."
+                )
+
     # Rescale back to the original distance scale
     if D_max > 0:
         Y = Y * D_max
@@ -130,10 +210,32 @@ def sgd_mds_metric(
     random_state=None,
     verbose=0,
 ):
133- """Auto-tuned SGD-MDS with optimal parameters for different data sizes"""
213+ """Auto-tuned SGD-MDS with optimal parameters for different data sizes
214+
215+ Automatically selects the number of iterations and pairs per iteration
216+ based on the dataset size.
217+
218+ Parameters
219+ ----------
220+ D : array-like, shape (n_samples, n_samples)
221+ Distance matrix
222+ n_components : int, default=2
223+ Number of dimensions for the output embedding
224+ init : array-like, shape (n_samples, n_components), optional
225+ Initial embedding (e.g., from classical MDS)
226+ random_state : int or RandomState, optional
227+ Random state for reproducibility
228+ verbose : int, default=0
229+ Verbosity level (0=silent, 1=progress, 2=debug)
230+
231+ Returns
232+ -------
233+ Y : array-like, shape (n_samples, n_components)
234+ Embedded coordinates
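+
+    Examples
+    --------
+    A minimal sketch (assumes a distance matrix D as in `sgd_mds`; iteration
+    counts and sampling sizes are chosen automatically)::
+
+        Y = sgd_mds_metric(D, n_components=2, random_state=0)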
235+ """
     n_samples = D.shape[0]

-    # Auto-tune: more iterations for larger n
+    # Auto-tune: more iterations for larger datasets
     if n_samples < 1000:
         n_iter = 300
         pairs_per_iter = n_samples * n_samples // 10  # ~10% of n^2, i.e. about 20% of the n(n-1)/2 unique pairs