-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathloss.py
More file actions
38 lines (30 loc) · 820 Bytes
/
loss.py
File metadata and controls
38 lines (30 loc) · 820 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
Module with the training loss functions (KL divergence and binary cross-entropy) and a helper to save an RGB image
"""
import tensorflow as tf
import matplotlib.pyplot as plt
def KL_loss(y_true, y_pred, latent_dim=128):
    """
    Calculate the Kullback–Leibler divergence KL(N(mean, sigma^2) || N(0, 1)).

    ``y_true`` is not used; it is accepted only so the function matches the
    Keras ``loss(y_true, y_pred)`` signature.

    Args:
        y_true: Unused.
        y_pred: Tensor of rank >= 2 whose first ``latent_dim`` entries along
            axis 1 hold the mean and whose next ``latent_dim`` entries hold
            log(sigma).
            NOTE(review): this packing is inferred from the slicing below —
            confirm against the layer that produces ``y_pred``.
        latent_dim: Number of latent dimensions. Defaults to 128, the value
            that was previously hard-coded.

    Returns:
        Scalar tensor: the per-element KL term averaged over every element
        (batch and latent dimensions).
    """
    mean = y_pred[:, :latent_dim]
    # log(sigma) occupies the columns *after* the mean. Slicing the same
    # columns as the mean would make the loss depend on the mean alone.
    log_sigma = y_pred[:, latent_dim:2 * latent_dim]
    # Closed-form KL of a diagonal Gaussian against N(0, I), per element:
    #   -log(sigma) + 0.5 * (mean^2 - 1 + sigma^2),  sigma^2 = exp(2 * log(sigma))
    loss = -log_sigma + 0.5 * (
        tf.math.square(mean) - 1.0 + tf.math.exp(2.0 * log_sigma)
    )
    return tf.math.reduce_mean(loss)
def custom_generator_loss(y_true, y_pred):
    """
    Compute binary cross-entropy between ``y_true`` and ``y_pred``.

    This is a thin wrapper around Keras' ``binary_crossentropy`` function.
    No extra reduction is applied here, so the result has whatever shape
    that function returns.
    """
    bce_fn = tf.keras.metrics.binary_crossentropy
    result = bce_fn(y_true, y_pred)
    return result
def save_rgb_img(img, path, title="Image"):
    """
    Save an RGB image to ``path`` as a single titled plot with axes hidden.

    Args:
        img: Image data accepted by ``Axes.imshow`` (e.g. an (H, W, 3) array).
        path: Destination passed to ``Figure.savefig``.
        title: Title drawn above the image. Defaults to "Image", the title
            that was previously hard-coded.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(title)
        # Save this specific figure rather than whatever pyplot currently
        # treats as the "current" figure.
        fig.savefig(path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        # Otherwise repeated calls accumulate open figures in pyplot's
        # global figure registry.
        plt.close(fig)