@@ -123,13 +123,6 @@ def solve_weights(
123
123
)
124
124
target_timestep = timesteps [- 1 ] # Use last timestep as target
125
125
126
- # Print debug info
127
- print (f"\n Solve weights debug info:" )
128
- print (f" Source gammas: { gammas .tolist ()} " )
129
- print (f" Target gamma: { target_gamma .item ()} " )
130
- print (f" Timesteps range: { timesteps [0 ].item ()} to { timesteps [- 1 ].item ()} " )
131
- print (f" Calculation dtype: { calculation_dtype } " )
132
-
133
126
# Pre-allocate tensor in calculation dtype
134
127
p_dot_p_matrix = torch .empty (
135
128
(len (gammas ), len (gammas )), dtype = calculation_dtype , device = gammas .device
@@ -142,14 +135,6 @@ def solve_weights(
142
135
timesteps [i ], gammas [i ], timesteps [j ], gammas [j ]
143
136
)
144
137
145
- # Print matrix properties
146
- print (f"\n Matrix properties:" )
147
- print (f" A shape: { p_dot_p_matrix .shape } " )
148
- print (f" A condition number: { torch .linalg .cond (p_dot_p_matrix ).item ():.2e} " )
149
- print (
150
- f" A min/max: { p_dot_p_matrix .min ().item ():.2e} /{ p_dot_p_matrix .max ().item ():.2e} "
151
- )
152
-
153
138
# Compute target vector
154
139
target_vector = torch .tensor (
155
140
[
@@ -160,68 +145,43 @@ def solve_weights(
160
145
device = gammas .device ,
161
146
)
162
147
163
- print (f" b shape: { target_vector .shape } " )
164
- print (
165
- f" b min/max: { target_vector .min ().item ():.2e} /{ target_vector .max ().item ():.2e} "
166
- )
167
-
168
148
# Use target_sigma_rel directly if provided, otherwise compute from gamma
169
149
if target_sigma_rel is None :
170
150
target_sigma_rel = float (
171
151
np .sqrt ((target_gamma + 1 ) / ((target_gamma + 2 ) * (target_gamma + 3 )))
172
152
)
173
- print (f"\n Solver selection:" )
174
- print (f" Target sigma_rel: { target_sigma_rel :.6f} " )
175
- print (f" Using { 'original' if target_sigma_rel <= 0.28 else 'stable' } solver" )
176
153
177
154
if target_sigma_rel <= 0.28 :
178
155
# Original solver for small sigma_rel values
179
156
try :
180
- print (" Attempting direct solve..." )
181
157
weights = torch .linalg .solve (p_dot_p_matrix , target_vector )
182
- print (" Direct solve succeeded" )
183
- print (f" Weights sum: { weights .sum ().item ():.6f} " )
184
- print (
185
- f" Weights min/max: { weights .min ().item ():.6f} /{ weights .max ().item ():.6f} "
186
- )
187
158
return weights
188
159
except RuntimeError as e :
189
- print (f" Direct solve failed: { str (e )} " )
190
- print (" Falling back to SVD..." )
160
+ print (f"Direct solve failed: { str (e )} " )
161
+ print ("Falling back to SVD..." )
191
162
# Original fallback
192
163
U , S , Vh = torch .linalg .svd (p_dot_p_matrix )
193
164
S_inv = torch .where (S > 0 , 1.0 / S , torch .zeros_like (S ))
194
165
weights = Vh .t () @ (
195
166
S_inv .unsqueeze (- 1 ) * (U .t () @ target_vector .unsqueeze (- 1 ))
196
167
)
197
168
weights = weights .squeeze ()
198
- print (" SVD solve succeeded" )
199
- print (f" Weights sum: { weights .sum ().item ():.6f} " )
200
- print (
201
- f" Weights min/max: { weights .min ().item ():.6f} /{ weights .max ().item ():.6f} "
202
- )
203
169
return weights
204
170
else :
205
171
# Use more robust solver for larger sigma_rel values
206
172
# Add moderate regularization for stability
207
173
p_dot_p_matrix .diagonal ().add_ (1e-6 )
208
174
209
175
try :
210
- print (" Attempting direct solve with regularization..." )
211
176
weights = torch .linalg .solve (p_dot_p_matrix , target_vector )
212
177
if torch .isfinite (weights ).all () and weights .abs ().max () < 1e3 :
213
- print (" Direct solve succeeded" )
214
- print (f" Weights sum: { weights .sum ().item ():.6f} " )
215
- print (
216
- f" Weights min/max: { weights .min ().item ():.6f} /{ weights .max ().item ():.6f} "
217
- )
218
178
return weights
219
- print (" Direct solve produced unstable weights" )
179
+ print ("Direct solve produced unstable weights" )
220
180
except RuntimeError as e :
221
- print (f" Direct solve failed: { str (e )} " )
181
+ print (f"Direct solve failed: { str (e )} " )
222
182
223
183
try :
224
- print (" Attempting SVD with stronger filtering..." )
184
+ print ("Attempting SVD with stronger filtering..." )
225
185
U , S , Vh = torch .linalg .svd (p_dot_p_matrix )
226
186
rcond = 1e-8
227
187
threshold = rcond * S .max ()
@@ -231,28 +191,18 @@ def solve_weights(
231
191
)
232
192
weights = weights .squeeze ()
233
193
if torch .isfinite (weights ).all () and weights .abs ().max () < 1e3 :
234
- print (" SVD solve succeeded" )
235
- print (f" Weights sum: { weights .sum ().item ():.6f} " )
236
- print (
237
- f" Weights min/max: { weights .min ().item ():.6f} /{ weights .max ().item ():.6f} "
238
- )
239
194
return weights
240
195
print (" SVD solve produced unstable weights" )
241
196
except RuntimeError as e :
242
197
print (f" SVD solve failed: { str (e )} " )
243
198
244
- print (" Using final fallback: damped least squares..." )
199
+ print ("Using final fallback: damped least squares..." )
245
200
reg_matrix = (
246
201
p_dot_p_matrix
247
202
+ torch .eye (len (gammas ), dtype = calculation_dtype , device = gammas .device )
248
203
* 1e-4
249
204
)
250
205
weights = torch .linalg .solve (reg_matrix , target_vector )
251
- print (" Damped least squares succeeded" )
252
- print (f" Weights sum: { weights .sum ().item ():.6f} " )
253
- print (
254
- f" Weights min/max: { weights .min ().item ():.6f} /{ weights .max ().item ():.6f} "
255
- )
256
206
return weights
257
207
258
208
0 commit comments