Commit bfffb4c

core\refac: #98 debug and fix loss
- debug and fix tgce loss variants
1 parent 497dcf8 commit bfffb4c

File tree (3 files changed, +67 −9 lines changed):
- core/seg_tgce/experiments/histology/features.py
- core/seg_tgce/loss/test_tgce.py
- core/seg_tgce/loss/tgce.py

core/seg_tgce/experiments/histology/features.py (12 additions, 6 deletions)

@@ -13,15 +13,18 @@

 from ..utils import handle_training

-TARGET_SHAPE = (256, 256)
-BATCH_SIZE = 32
+TARGET_SHAPE = (128, 128)
+BATCH_SIZE = 4
 TRAIN_EPOCHS = 20
 TUNER_EPOCHS = 1

 DEFAULT_HPARAMS = {
     "initial_learning_rate": 1e-3,
     "q": 0.5,
     "noise_tolerance": 0.5,
+    "a": 0.5,
+    "b": 0.5,
+    "c": 1.0,
     "lambda_reg_weight": 0.1,
     "lambda_entropy_weight": 0.1,
     "lambda_sum_weight": 0.1,

@@ -49,12 +52,18 @@ def build_model(hp=None):
             "lambda_sum_weight": hp.Float(
                 "lambda_sum_weight", min_value=0.01, max_value=0.5, step=0.01
             ),
+            "a": hp.Float("a", min_value=0.0, max_value=1.0, step=0.1),
+            "b": hp.Float("b", min_value=0.0, max_value=1.0, step=0.1),
+            "c": hp.Float("c", min_value=0.0, max_value=1.0, step=0.1),
         }

     return build_features_model_from_hparams(
         learning_rate=params["initial_learning_rate"],
         q=params["q"],
         noise_tolerance=params["noise_tolerance"],
+        a=params["a"],
+        b=params["b"],
+        c=params["c"],
         lambda_reg_weight=params["lambda_reg_weight"],
         lambda_entropy_weight=params["lambda_entropy_weight"],
         lambda_sum_weight=params["lambda_sum_weight"],

@@ -76,10 +85,7 @@ def build_model(hp=None):
     args = parser.parse_args()

     processed_train, processed_validation, processed_test = get_processed_data(
-        image_size=TARGET_SHAPE,
-        batch_size=BATCH_SIZE,
-        use_augmentation=True,
-        augmentation_factor=2,
+        image_size=TARGET_SHAPE, batch_size=BATCH_SIZE, use_augmentation=False
     )

     model = handle_training(
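
Note (not part of the commit): a minimal sketch of how build_model above is typically driven, with hp=None falling back to DEFAULT_HPARAMS and a Keras Tuner search sampling a, b, c and the lambda weights. The tuner class, objective name, and trial count here are assumptions, not taken from this repository; build_model, processed_train, processed_validation, and TUNER_EPOCHS come from the script above.

import keras_tuner

# Hypothetical tuning driver for build_model (assumed setup, for illustration only).
tuner = keras_tuner.RandomSearch(
    hypermodel=build_model,
    objective="val_loss",
    max_trials=5,
    overwrite=True,
)
tuner.search(processed_train, validation_data=processed_validation, epochs=TUNER_EPOCHS)
best_model = tuner.get_best_models(num_models=1)[0]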

core/seg_tgce/loss/test_tgce.py (new file, 43 additions, 0 deletions)

@@ -0,0 +1,43 @@
+import tensorflow as tf
+
+from seg_tgce.data.oxford_pet.oxford_pet_tfds import get_data_multiple_annotators_tfds
+from seg_tgce.loss.tgce import TgceFeatures, TgcePixel, TgceScalar
+
+if __name__ == "__main__":
+    NOISE_LEVELS_SNR = [-20.0, 20.0]
+    TARGET_SHAPE = (128, 128)
+    BATCH_SIZE = 16
+    NUM_CLASSES = 3
+
+    train_data, val_data, test_data = get_data_multiple_annotators_tfds(
+        noise_levels_snr=NOISE_LEVELS_SNR,
+        target_shape=TARGET_SHAPE,
+        batch_size=BATCH_SIZE,
+        labeling_rate=0.7,
+    )
+    scalar_loss = TgceScalar(num_classes=NUM_CLASSES)
+    features_loss = TgceFeatures(num_classes=NUM_CLASSES)
+    pixel_loss = TgcePixel(num_classes=NUM_CLASSES)
+
+    for i, data in enumerate(train_data):
+        ground_truth = data["ground_truth"]
+        images = data["image"]
+        masks = data["masks"]
+        labeler_mask = data["labeler_mask"]
+        print(ground_truth.shape)
+        print(images.shape)
+        print(masks.shape)
+        print(labeler_mask.shape)
+        # lambda_r filled with ones, one shape per loss variant
+        scalar_lambda_r = tf.ones((BATCH_SIZE, len(NOISE_LEVELS_SNR)))
+        features_lambda_r = tf.ones((BATCH_SIZE, 4, 4, len(NOISE_LEVELS_SNR)))
+        pixel_lambda_r = tf.ones((BATCH_SIZE, *TARGET_SHAPE, len(NOISE_LEVELS_SNR)))
+        y_pred = tf.random.normal((BATCH_SIZE, *TARGET_SHAPE, NUM_CLASSES))
+
+        scalar_loss_value = scalar_loss.call(masks, y_pred, scalar_lambda_r, labeler_mask)
+        features_loss_value = features_loss.call(
+            masks, y_pred, features_lambda_r, labeler_mask
+        )
+        pixel_loss_value = pixel_loss.call(masks, y_pred, pixel_lambda_r, labeler_mask)
+
+        break
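
Note (not part of the commit): a possible hardening of the smoke script above, assuming each .call returns a single loss tensor (an assumption about the API), so the run fails loudly on NaN/Inf instead of only printing shapes.

# Hypothetical follow-up checks, reusing the *_loss_value tensors computed in the loop above.
tf.debugging.assert_all_finite(scalar_loss_value, "scalar TGCE loss is not finite")
tf.debugging.assert_all_finite(features_loss_value, "features TGCE loss is not finite")
tf.debugging.assert_all_finite(pixel_loss_value, "pixel TGCE loss is not finite")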

core/seg_tgce/loss/tgce.py (12 additions, 3 deletions)

@@ -188,7 +188,11 @@ def call(
         y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0 - self.epsilon)
         lambda_r = tf.clip_by_value(lambda_r, self.epsilon, 1.0 - self.epsilon)

-        reg_term = reliability_penalizer(labeler_mask, lambda_r, self.a, self.b, self.c)
+        lambda_r_reduced = tf.reduce_mean(lambda_r, axis=(1, 2))
+
+        reg_term = reliability_penalizer(
+            labeler_mask, lambda_r_reduced, self.a, self.b, self.c
+        )
         # Expand predictions to match annotators dimension
         y_pred_exp = tf.expand_dims(y_pred, axis=-1)
         y_pred_exp = tf.tile(y_pred_exp, [1, 1, 1, 1, tf.shape(y_true)[-1]])

@@ -222,7 +226,7 @@ def call(
         )

         lambda_sum = self.lambda_sum_weight * tf.reduce_mean(
-            tf.square(tf.reduce_sum(valid_lambda_r, axis=-2) - 1.0)
+            tf.square(tf.reduce_sum(valid_lambda_r, axis=-1) - 1.0)
         )

         total_loss = (

@@ -305,7 +309,12 @@ def call(

         y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0 - self.epsilon)
         lambda_r = tf.clip_by_value(lambda_r, self.epsilon, 1.0 - self.epsilon)
-        reg_term = reliability_penalizer(labeler_mask, lambda_r, self.a, self.b, self.c)
+
+        lambda_r_reduced = tf.reduce_mean(lambda_r, axis=(1, 2))
+
+        reg_term = reliability_penalizer(
+            labeler_mask, lambda_r_reduced, self.a, self.b, self.c
+        )

         # Expand predictions to match annotators dimension
         y_pred_exp = tf.expand_dims(y_pred, axis=-1)
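
Note (not part of the commit): a minimal shape sketch of why the fix averages lambda_r over the spatial axes before reliability_penalizer, and what the corrected lambda_sum axis penalizes. Shapes are illustrative assumptions: for the pixel variant lambda_r is taken to be (batch, H, W, annotators), while the penalizer consumes one reliability per (sample, annotator).

import tensorflow as tf

batch, height, width, annotators = 4, 8, 8, 2
lambda_r = tf.random.uniform((batch, height, width, annotators))

# Collapse the spatial axes -> one reliability value per (sample, annotator).
lambda_r_reduced = tf.reduce_mean(lambda_r, axis=(1, 2))
print(lambda_r_reduced.shape)  # (4, 2)

# Summing over the last (annotator) axis pushes each sample's annotator
# reliabilities toward summing to 1, which is what the corrected penalty does.
lambda_sum = tf.reduce_mean(tf.square(tf.reduce_sum(lambda_r_reduced, axis=-1) - 1.0))
print(float(lambda_sum))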
