-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlosses.py
More file actions
59 lines (49 loc) · 2.07 KB
/
losses.py
File metadata and controls
59 lines (49 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn.functional as F
def loss_recon(X_synthetic: torch.Tensor, X_current: torch.Tensor) -> torch.Tensor:
    """
    Reconstruction loss: mean squared error between the synthetic and the real
    current mammogram, used to keep the generated image structurally faithful.

    Args:
        X_synthetic: generated image tensor.
        X_current: reference (real current) image tensor.

    Returns:
        Scalar tensor, the element-wise squared error averaged over all elements.
    """
    prediction, target = X_synthetic, X_current
    # "mean" is F.mse_loss's default reduction; spelled out for readability.
    return F.mse_loss(prediction, target, reduction="mean")
def loss_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """
    KL divergence between the diagonal Gaussian N(mu, exp(logvar)) and N(0, I),
    which pulls the latent distribution toward a standard normal.

    Per sample: L_KL = -1/2 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)

    Args:
        mu: latent means. Presumably shaped (batch, latent_dim); verify with callers.
        logvar: latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor. The term is summed over dim 1 and then averaged over
        every remaining element (the batch, for 2-D inputs).
    """
    elementwise = 1 + logvar - mu ** 2 - torch.exp(logvar)
    per_sample = elementwise.sum(dim=1)
    return -0.5 * per_sample.mean()
def loss_tumor_bce(T_hat: torch.Tensor, M_gt: torch.Tensor) -> torch.Tensor:
    """
    Tumor loss (binary cross-entropy).

    Supervises tumor-region predictions when ground-truth masks are available.

    Args:
        T_hat: predicted tumor probabilities. F.binary_cross_entropy expects
            values in [0, 1], so this is assumed to be post-sigmoid. That
            assumption is not checked here.
        M_gt: ground-truth mask with the same shape as ``T_hat``. Any dtype is
            accepted (bool, uint8, int64, float64, ...). It is cast to
            ``T_hat``'s dtype before the loss is computed.

    Returns:
        Scalar tensor, the mean BCE over all elements.
    """
    # F.binary_cross_entropy rejects a target whose dtype differs from the
    # input ("Found dtype Long but expected Float"). Masks loaded from disk are
    # often bool/uint8/long, so align the dtype. When the dtypes already match,
    # .to() returns the same tensor and nothing is copied.
    target = M_gt.to(dtype=T_hat.dtype)
    return F.binary_cross_entropy(T_hat, target)
def gan_terms(D, X_prior, X_current, X_synthetic, epsilon: float):
    """
    Adversarial (GAN) terms for a discriminator that scores conditioned triplets.

    The discriminator ``D`` sees the prior and current mammograms, concatenated
    along dim 1 together with a candidate image:
      real triplet: [X_prior, X_current, X_current]
      fake triplet: [X_prior, X_current, X_synthetic]

    Args:
        D: discriminator callable. Its outputs are treated as probabilities:
            they are clamped to [epsilon, 1 - epsilon] and passed to log.
        X_prior: prior mammogram tensor.
        X_current: current mammogram tensor.
        X_synthetic: generated mammogram tensor. It is NOT detached here.
        epsilon: clamp margin that keeps both log terms finite.

    Returns:
        (y_real, y_fake, L_GAN, L_D), where
          y_real: clamped D scores on the real triplets
          y_fake: clamped D scores on the fake triplets
          L_GAN:  mean(log(y_real) + log(1 - y_fake)), the minimax objective;
                  only the second term depends on X_synthetic
          L_D:    -L_GAN, the discriminator's loss to minimise

    NOTE(review): L_D shares the graph with X_synthetic because nothing here
    calls detach(). Backpropagating L_D therefore also sends gradients into
    whatever produced X_synthetic. Confirm that the training loop detaches or
    zeroes the generator's grads. Also, clamp() has zero gradient wherever D
    saturates outside [epsilon, 1 - epsilon].
    """
    lo, hi = epsilon, 1 - epsilon

    def score(candidate: torch.Tensor) -> torch.Tensor:
        # Condition on (prior, current) and append the image under test.
        triplet = torch.cat([X_prior, X_current, candidate], dim=1)
        return D(triplet).clamp(lo, hi)

    # Real is scored before fake. If D is stateful (e.g. dropout or BatchNorm
    # in train mode), changing the call order could change the results.
    y_real = score(X_current)
    y_fake = score(X_synthetic)

    log_real = torch.log(y_real)
    log_not_fake = torch.log(1.0 - y_fake)
    L_GAN = (log_real + log_not_fake).mean()
    L_D = -L_GAN
    return y_real, y_fake, L_GAN, L_D