Commit 7178042

added logging
1 parent 506b9ca commit 7178042

File tree

2 files changed, +68 -2 lines changed


examples/research_projects/lpl/lpl_loss.py

Lines changed: 15 additions & 0 deletions

@@ -61,6 +61,7 @@ def __init__(
         self.pow_law = pow_law
         self.norm_type = norm_type.lower()
         self.outlier_mask = remove_outliers
+        self.last_feature_stats = []  # Store feature statistics for logging

         assert feature_type in ["feature", "image"]
         self.feature_type = feature_type
@@ -132,15 +133,29 @@ def get_loss(self, input, target, get_hist=False):
         inp_f = self.get_features(self.shift + input / self.scale)
         tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
         losses = []
+        self.last_feature_stats = []  # Reset feature stats

         for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
             my = torch.ones_like(y).bool()
+            outlier_ratio = 0.0
+
             if self.outlier_mask:
                 with torch.no_grad():
                     if i == 2:
                         my, y = remove_outliers(y, down_f=2)
+                        outlier_ratio = 1.0 - my.float().mean().item()
                     elif i in [3, 4, 5]:
                         my, y = remove_outliers(y, down_f=1)
+                        outlier_ratio = 1.0 - my.float().mean().item()
+
+            # Store feature statistics before normalization
+            with torch.no_grad():
+                stats = {
+                    'mean': y.mean().item(),
+                    'std': y.std().item(),
+                    'outlier_ratio': outlier_ratio,
+                }
+            self.last_feature_stats.append(stats)

             # normalize feature tensors
             if self.norm_type == "default":
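
Since every statistic is computed under torch.no_grad() and converted to a Python float with .item(), last_feature_stats holds no live tensors and cannot keep the autograd graph alive between steps. A minimal sketch of reading the stats after a loss computation; pred and target are hypothetical latent tensors, and lpl_fn stands for an instance of the loss class in this file (the name used for it in lpl_sdxl.py):

    # Sketch only: `pred` and `target` are placeholders, not names from this
    # commit; get_loss() and last_feature_stats come from the diff above.
    loss = lpl_fn.get_loss(pred, target)
    for layer_idx, stats in enumerate(lpl_fn.last_feature_stats):
        # Each entry is a dict of plain Python floats, safe for any logger.
        print(
            f"layer {layer_idx}: mean={stats['mean']:.4f} "
            f"std={stats['std']:.4f} outlier_ratio={stats['outlier_ratio']:.2%}"
        )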

examples/research_projects/lpl/lpl_sdxl.py

Lines changed: 53 additions & 2 deletions

@@ -1354,14 +1354,65 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 progress_bar.update(1)
                 global_step += 1

+                # Enhanced logging for LPL metrics
                 log_data = {
                     "train_loss": train_loss,
                     "diffusion_loss": loss.item(),
+                    "learning_rate": lr_scheduler.get_last_lr()[0],
                 }
-                if args.use_lpl and lpl_loss_value.item() > 0:
-                    log_data["lpl_loss"] = lpl_loss_value.item()
+
+                if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start:
+                    if lpl_mask.any():
+                        # LPL application statistics
+                        log_data.update({
+                            "lpl/loss": lpl_loss_value.item(),
+                            "lpl/num_samples": lpl_mask.sum().item(),
+                            "lpl/application_ratio": lpl_mask.float().mean().item(),
+                            "lpl/weight": args.lpl_weight,
+                            "lpl/weighted_loss": (args.lpl_weight * lpl_loss_value).item(),
+                        })
+
+                        # SNR statistics for LPL-applied samples
+                        if args.snr_gamma is not None:
+                            snr_values = snr[masked_indices]
+                            log_data.update({
+                                "lpl/snr_mean": snr_values.mean().item(),
+                                "lpl/snr_std": snr_values.std().item(),
+                                "lpl/snr_min": snr_values.min().item(),
+                                "lpl/snr_max": snr_values.max().item(),
+                            })
+
+                        # Feature statistics if available (entries are already plain floats)
+                        if hasattr(lpl_fn, 'last_feature_stats'):
+                            for layer_idx, stats in enumerate(lpl_fn.last_feature_stats):
+                                log_data.update({
+                                    f"lpl/features/layer_{layer_idx}/mean": stats['mean'],
+                                    f"lpl/features/layer_{layer_idx}/std": stats['std'],
+                                    f"lpl/features/layer_{layer_idx}/outlier_ratio": stats.get('outlier_ratio', 0.0),
+                                })
+
+                        # Memory usage if available
+                        if torch.cuda.is_available():
+                            log_data.update({
+                                "lpl/memory/allocated": torch.cuda.memory_allocated() / 1024**2,  # MB
+                                "lpl/memory/reserved": torch.cuda.memory_reserved() / 1024**2,  # MB
+                            })
+
+                # Log to accelerator
                 accelerator.log(log_data, step=global_step)

+                # Update progress bar with more metrics
+                progress_bar_logs = {
+                    "loss": loss.detach().item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                if args.use_lpl and lpl_loss_value.item() > 0:
+                    progress_bar_logs.update({
+                        "lpl": lpl_loss_value.item(),
+                        "lpl_ratio": lpl_mask.float().mean().item() if lpl_mask.any() else 0.0,
+                    })
+                progress_bar.set_postfix(**progress_bar_logs)
+
                 # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
                 if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
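
The slash-separated key names ("lpl/loss", "lpl/memory/allocated") are grouped into collapsible sections by trackers such as TensorBoard and Weights & Biases, which keeps the new metrics from cluttering the top-level charts. A minimal, self-contained sketch of the same accelerator.log pattern, assuming a TensorBoard tracker; the project name and metric values are placeholders:

    # Sketch only: tracker choice, project name, and values are placeholders.
    from accelerate import Accelerator

    accelerator = Accelerator(log_with="tensorboard", project_dir="logs")
    accelerator.init_trackers("lpl_logging_demo")

    for global_step in range(3):
        log_data = {
            "train_loss": 0.1 / (global_step + 1),
            "lpl/loss": 0.05,  # slash-separated keys form groups in the UI
        }
        accelerator.log(log_data, step=global_step)

    accelerator.end_training()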
