Skip to content

Commit 5bacfd4

Browse files
committed
improve: less debug logging
1 parent 8230348 commit 5bacfd4

File tree

1 file changed

+6
-56
lines changed

1 file changed

+6
-56
lines changed

posthoc_ema/utils.py

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,6 @@ def solve_weights(
123123
)
124124
target_timestep = timesteps[-1] # Use last timestep as target
125125

126-
# Print debug info
127-
print(f"\nSolve weights debug info:")
128-
print(f" Source gammas: {gammas.tolist()}")
129-
print(f" Target gamma: {target_gamma.item()}")
130-
print(f" Timesteps range: {timesteps[0].item()} to {timesteps[-1].item()}")
131-
print(f" Calculation dtype: {calculation_dtype}")
132-
133126
# Pre-allocate tensor in calculation dtype
134127
p_dot_p_matrix = torch.empty(
135128
(len(gammas), len(gammas)), dtype=calculation_dtype, device=gammas.device
@@ -142,14 +135,6 @@ def solve_weights(
142135
timesteps[i], gammas[i], timesteps[j], gammas[j]
143136
)
144137

145-
# Print matrix properties
146-
print(f"\nMatrix properties:")
147-
print(f" A shape: {p_dot_p_matrix.shape}")
148-
print(f" A condition number: {torch.linalg.cond(p_dot_p_matrix).item():.2e}")
149-
print(
150-
f" A min/max: {p_dot_p_matrix.min().item():.2e}/{p_dot_p_matrix.max().item():.2e}"
151-
)
152-
153138
# Compute target vector
154139
target_vector = torch.tensor(
155140
[
@@ -160,68 +145,43 @@ def solve_weights(
160145
device=gammas.device,
161146
)
162147

163-
print(f" b shape: {target_vector.shape}")
164-
print(
165-
f" b min/max: {target_vector.min().item():.2e}/{target_vector.max().item():.2e}"
166-
)
167-
168148
# Use target_sigma_rel directly if provided, otherwise compute from gamma
169149
if target_sigma_rel is None:
170150
target_sigma_rel = float(
171151
np.sqrt((target_gamma + 1) / ((target_gamma + 2) * (target_gamma + 3)))
172152
)
173-
print(f"\nSolver selection:")
174-
print(f" Target sigma_rel: {target_sigma_rel:.6f}")
175-
print(f" Using {'original' if target_sigma_rel <= 0.28 else 'stable'} solver")
176153

177154
if target_sigma_rel <= 0.28:
178155
# Original solver for small sigma_rel values
179156
try:
180-
print(" Attempting direct solve...")
181157
weights = torch.linalg.solve(p_dot_p_matrix, target_vector)
182-
print(" Direct solve succeeded")
183-
print(f" Weights sum: {weights.sum().item():.6f}")
184-
print(
185-
f" Weights min/max: {weights.min().item():.6f}/{weights.max().item():.6f}"
186-
)
187158
return weights
188159
except RuntimeError as e:
189-
print(f" Direct solve failed: {str(e)}")
190-
print(" Falling back to SVD...")
160+
print(f"Direct solve failed: {str(e)}")
161+
print("Falling back to SVD...")
191162
# Original fallback
192163
U, S, Vh = torch.linalg.svd(p_dot_p_matrix)
193164
S_inv = torch.where(S > 0, 1.0 / S, torch.zeros_like(S))
194165
weights = Vh.t() @ (
195166
S_inv.unsqueeze(-1) * (U.t() @ target_vector.unsqueeze(-1))
196167
)
197168
weights = weights.squeeze()
198-
print(" SVD solve succeeded")
199-
print(f" Weights sum: {weights.sum().item():.6f}")
200-
print(
201-
f" Weights min/max: {weights.min().item():.6f}/{weights.max().item():.6f}"
202-
)
203169
return weights
204170
else:
205171
# Use more robust solver for larger sigma_rel values
206172
# Add moderate regularization for stability
207173
p_dot_p_matrix.diagonal().add_(1e-6)
208174

209175
try:
210-
print(" Attempting direct solve with regularization...")
211176
weights = torch.linalg.solve(p_dot_p_matrix, target_vector)
212177
if torch.isfinite(weights).all() and weights.abs().max() < 1e3:
213-
print(" Direct solve succeeded")
214-
print(f" Weights sum: {weights.sum().item():.6f}")
215-
print(
216-
f" Weights min/max: {weights.min().item():.6f}/{weights.max().item():.6f}"
217-
)
218178
return weights
219-
print(" Direct solve produced unstable weights")
179+
print("Direct solve produced unstable weights")
220180
except RuntimeError as e:
221-
print(f" Direct solve failed: {str(e)}")
181+
print(f"Direct solve failed: {str(e)}")
222182

223183
try:
224-
print(" Attempting SVD with stronger filtering...")
184+
print("Attempting SVD with stronger filtering...")
225185
U, S, Vh = torch.linalg.svd(p_dot_p_matrix)
226186
rcond = 1e-8
227187
threshold = rcond * S.max()
@@ -231,28 +191,18 @@ def solve_weights(
231191
)
232192
weights = weights.squeeze()
233193
if torch.isfinite(weights).all() and weights.abs().max() < 1e3:
234-
print(" SVD solve succeeded")
235-
print(f" Weights sum: {weights.sum().item():.6f}")
236-
print(
237-
f" Weights min/max: {weights.min().item():.6f}/{weights.max().item():.6f}"
238-
)
239194
return weights
240195
print(" SVD solve produced unstable weights")
241196
except RuntimeError as e:
242197
print(f" SVD solve failed: {str(e)}")
243198

244-
print(" Using final fallback: damped least squares...")
199+
print("Using final fallback: damped least squares...")
245200
reg_matrix = (
246201
p_dot_p_matrix
247202
+ torch.eye(len(gammas), dtype=calculation_dtype, device=gammas.device)
248203
* 1e-4
249204
)
250205
weights = torch.linalg.solve(reg_matrix, target_vector)
251-
print(" Damped least squares succeeded")
252-
print(f" Weights sum: {weights.sum().item():.6f}")
253-
print(
254-
f" Weights min/max: {weights.min().item():.6f}/{weights.max().item():.6f}"
255-
)
256206
return weights
257207

258208

0 commit comments

Comments
 (0)