Skip to content

Commit 42a9beb

Browse files
committed
core\refac: #98 full hp (scalar)
- full hps availability for scalar experiment - added missing deps for optuna meaningful hps plotting
1 parent 4fed535 commit 42a9beb

File tree

7 files changed

+400

-41

lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ core/dist
1111
__data__
1212
**/results/**
1313
tuner_results/
14-
**/vis/
14+
**/vis/
15+
**/optuna_results/

core/poetry.lock

Lines changed: 299 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

core/pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
description = "Framework for handling image segmentation in the context of multiple annotators"
33
name = "seg_tgce"
4-
version = "0.3.4"
4+
version = "0.3.7"
55
readme = "README.md"
66
authors = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }]
77
maintainers = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }]
@@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"
1515

1616
[tool.poetry]
1717
name = "seg_tgce"
18-
version = "0.3.4"
18+
version = "0.3.7"
1919
authors = ["Brandon Lotero <blotero@gmail.com>"]
2020
description = "A package for the SEG TGCE project"
2121
readme = "README.md"
@@ -44,6 +44,8 @@ albumentations = "^2.0.7"
4444
pandas = "^2.2.3"
4545
seaborn = "^0.13.2"
4646
optuna = "^4.4.0"
47+
plotly = "^6.2.0"
48+
kaleido = "^1.0.0"
4749

4850

4951
[tool.poetry.group.test.dependencies]

core/seg_tgce/experiments/histology/scalar.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
"noise_tolerance": 0.5,
2828
"a": 0.3,
2929
"b": 0.7,
30+
"c": 1.0,
31+
"lambda_reg_weight": 0.1,
32+
"lambda_entropy_weight": 0.1,
33+
"lambda_sum_weight": 0.1,
3034
}
3135

3236

@@ -40,16 +44,27 @@ def build_model(hp: kt.HyperParameters | None = None) -> tf.keras.Model:
4044
"noise_tolerance": hp.Float(
4145
"noise_tolerance", min_value=0.1, max_value=0.9, step=0.1
4246
),
43-
"b": hp.Float("b", min_value=0.1, max_value=1.0, step=0.1),
4447
"a": hp.Float("a", min_value=0.1, max_value=1.0, step=0.1),
48+
"b": hp.Float("b", min_value=0.1, max_value=1.0, step=0.1),
49+
"c": hp.Float("c", min_value=0.1, max_value=10.0, step=0.1),
50+
"lambda_reg_weight": hp.Float(
51+
"lambda_reg_weight", min_value=0.0, max_value=10.0, step=0.1
52+
),
53+
"lambda_entropy_weight": hp.Float(
54+
"lambda_entropy_weight", min_value=0.0, max_value=10.0, step=0.1
55+
),
4556
}
4657

4758
return build_scalar_model_from_hparams(
4859
learning_rate=params["initial_learning_rate"],
4960
q=params["q"],
5061
noise_tolerance=params["noise_tolerance"],
51-
b=params["b"],
5262
a=params["a"],
63+
b=params["b"],
64+
c=params["c"],
65+
lambda_reg_weight=params["lambda_reg_weight"],
66+
lambda_entropy_weight=params["lambda_entropy_weight"],
67+
lambda_sum_weight=params["lambda_sum_weight"],
5368
num_classes=N_CLASSES,
5469
target_shape=TARGET_SHAPE,
5570
n_scorers=N_REAL_SCORERS,

core/seg_tgce/experiments/pets/scalar.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
"noise_tolerance": 0.5,
3131
"a": 0.2,
3232
"b": 0.7,
33+
"c": 1.0,
34+
"lambda_reg_weight": 0.1,
35+
"lambda_entropy_weight": 0.1,
36+
"lambda_sum_weight": 0.1,
3337
}
3438

3539

@@ -40,7 +44,11 @@ def build_model_from_trial(trial: HpTunerTrial | None) -> Model:
4044
q=DEFAULT_HPARAMS["q"],
4145
noise_tolerance=DEFAULT_HPARAMS["noise_tolerance"],
4246
b=DEFAULT_HPARAMS["b"],
47+
c=DEFAULT_HPARAMS["c"],
4348
a=DEFAULT_HPARAMS["a"],
49+
lambda_reg_weight=DEFAULT_HPARAMS["lambda_reg_weight"],
50+
lambda_entropy_weight=DEFAULT_HPARAMS["lambda_entropy_weight"],
51+
lambda_sum_weight=DEFAULT_HPARAMS["lambda_sum_weight"],
4452
num_classes=NUM_CLASSES,
4553
target_shape=TARGET_SHAPE,
4654
n_scorers=NUM_SCORERS,
@@ -52,6 +60,12 @@ def build_model_from_trial(trial: HpTunerTrial | None) -> Model:
5260
noise_tolerance=trial.suggest_float("noise_tolerance", 0.1, 0.9, step=0.01),
5361
b=trial.suggest_float("b", 0.1, 1.0, step=0.01),
5462
a=trial.suggest_float("a", 0.1, 1.0, step=0.01),
63+
c=trial.suggest_float("c", 0.1, 10.0, step=0.1),
64+
lambda_reg_weight=trial.suggest_float("lambda_reg_weight", 0.0, 10.0, step=0.1),
65+
lambda_entropy_weight=trial.suggest_float(
66+
"lambda_entropy_weight", 0.0, 10.0, step=0.1
67+
),
68+
lambda_sum_weight=trial.suggest_float("lambda_sum_weight", 0.0, 10.0, step=0.1),
5569
num_classes=NUM_CLASSES,
5670
target_shape=TARGET_SHAPE,
5771
n_scorers=NUM_SCORERS,

core/seg_tgce/loss/tgce.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99

1010

1111
def safe_divide(numerator: Tensor, denominator: Tensor, epsilon: float = 1e-8) -> Tensor:
12-
"""Safely divide two tensors, avoiding division by zero."""
1312
return tf.math.divide(
1413
numerator, tf.clip_by_value(denominator, epsilon, tf.reduce_max(denominator))
1514
)
1615

1716

1817
def safe_pow(x: Tensor, p: Tensor, epsilon: float = 1e-8) -> Tensor:
19-
"""Compute x^p safely by ensuring x is within a valid range."""
2018
return tf.pow(tf.clip_by_value(x, epsilon, 1.0 - epsilon), p)
2119

2220

23-
class TcgeScalar(Loss):
21+
def reliability_penalizer(
22+
lms: Tensor, lambdas: Tensor, a: float, b: float, c: float
23+
) -> Tensor:
24+
x = lambdas - lms
25+
return c * tf.maximum(1 / (1 - a) * x * tf.exp((x - 1) / b), 0)
26+
27+
28+
class TgceScalar(Loss):
2429
"""
2530
Truncated generalized cross entropy
2631
for semantic segmentation loss.
@@ -35,32 +40,40 @@ def __init__( # pylint: disable=too-many-arguments
3540
noise_tolerance: float = 0.1,
3641
a: float = 0.7,
3742
b: float = 0.7,
43+
c: float = 1.0,
44+
lambda_reg_weight: float = 0.1,
45+
lambda_entropy_weight: float = 0.1,
46+
lambda_sum_weight: float = 0.1,
3847
epsilon: float = 1e-8,
3948
) -> None:
4049
self.q = q
4150
self.num_classes = num_classes
4251
self.noise_tolerance = noise_tolerance
4352
self.a = a
4453
self.b = b
54+
self.c = c
55+
self.lambda_reg_weight = lambda_reg_weight
56+
self.lambda_entropy_weight = lambda_entropy_weight
57+
self.lambda_sum_weight = lambda_sum_weight
4558
self.epsilon = epsilon
4659
super().__init__(name=name)
4760

48-
def penalizer(self, lms: tf.Tensor, lambdas: tf.Tensor) -> tf.Tensor:
49-
"""Compute the penalizer term for reliability regularization."""
50-
x = lambdas - lms
51-
return tf.maximum(1 / (1 - self.a) * x * tf.exp((x - 1) / self.b), 0)
52-
5361
def call(
5462
self,
5563
y_true: tf.Tensor,
5664
y_pred: tf.Tensor,
5765
lambda_r: tf.Tensor,
5866
labeler_mask: tf.Tensor,
5967
) -> tf.Tensor:
68+
# Cast inputs to target data type
69+
y_true = tf.cast(y_true, TARGET_DATA_TYPE)
70+
y_pred = tf.cast(y_pred, TARGET_DATA_TYPE)
71+
lambda_r = tf.cast(lambda_r, TARGET_DATA_TYPE)
72+
6073
y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0 - self.epsilon)
6174
lambda_r = tf.clip_by_value(lambda_r, self.epsilon, 1.0 - self.epsilon)
6275

63-
reg_term = self.penalizer(labeler_mask, lambda_r)
76+
reg_term = reliability_penalizer(labeler_mask, lambda_r, self.a, self.b, self.c)
6477

6578
y_pred_exp = tf.expand_dims(y_pred, axis=-1)
6679
y_pred_exp = tf.tile(y_pred_exp, [1, 1, 1, 1, tf.shape(y_true)[-1]])
@@ -78,7 +91,28 @@ def call(
7891
(1.0 - tf.pow(self.noise_tolerance, self.q)) / (self.q + self.epsilon)
7992
)
8093

81-
total_loss = tf.reduce_mean(term1 + term2) + reg_term
94+
# Only compute regularization terms for valid labelers
95+
valid_lambda_r = lambda_r * tf.expand_dims(tf.expand_dims(labeler_mask, 1), 1)
96+
lambda_reg = self.lambda_reg_weight * tf.reduce_mean(
97+
tf.square(valid_lambda_r - 0.5)
98+
)
99+
100+
lambda_entropy = -self.lambda_entropy_weight * tf.reduce_mean(
101+
valid_lambda_r * tf.math.log1p(valid_lambda_r)
102+
+ (1 - valid_lambda_r) * tf.math.log1p(1 - valid_lambda_r)
103+
)
104+
105+
lambda_sum = self.lambda_sum_weight * tf.reduce_mean(
106+
tf.square(tf.reduce_sum(valid_lambda_r, axis=-1) - 1.0)
107+
)
108+
109+
total_loss = (
110+
tf.reduce_mean(term1 + term2)
111+
+ reg_term
112+
+ lambda_reg
113+
+ lambda_entropy
114+
+ lambda_sum
115+
)
82116

83117
total_loss = tf.where(
84118
tf.math.is_nan(total_loss),
@@ -99,11 +133,14 @@ def get_config(
99133
**base_config,
100134
"q": self.q,
101135
"b": self.b,
136+
"lambda_reg_weight": self.lambda_reg_weight,
137+
"lambda_entropy_weight": self.lambda_entropy_weight,
138+
"lambda_sum_weight": self.lambda_sum_weight,
102139
"epsilon": self.epsilon,
103140
}
104141

105142

106-
class TcgeFeatures(Loss):
143+
class TgceFeatures(Loss):
107144
"""
108145
Truncated generalized cross entropy for semantic segmentation loss
109146
with feature-based reliability (reliability map from bottleneck features).
@@ -210,7 +247,7 @@ def get_config(
210247
}
211248

212249

213-
class TcgePixel(Loss):
250+
class TgcePixel(Loss):
214251
"""
215252
Truncated generalized cross entropy for semantic segmentation loss
216253
with pixel-wise reliability (full resolution reliability map).

core/seg_tgce/models/builders.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from keras import Model
22
from keras.optimizers import Adam
33

4-
from seg_tgce.loss.tgce import TcgeFeatures, TcgePixel, TcgeScalar, TgceBaseline
4+
from seg_tgce.loss.tgce import TgceBaseline, TgceFeatures, TgcePixel, TgceScalar
55
from seg_tgce.metrics import DiceCoefficient, JaccardCoefficient
66
from seg_tgce.models.unet import (
77
unet_baseline,
@@ -55,39 +55,32 @@ def build_baseline_model_from_hparams(
5555

5656

5757
def build_scalar_model_from_hparams(
58+
*,
5859
learning_rate: float,
5960
q: float,
6061
noise_tolerance: float,
61-
b: float,
6262
a: float,
63+
b: float,
64+
c: float,
65+
lambda_reg_weight: float,
66+
lambda_entropy_weight: float,
67+
lambda_sum_weight: float,
6368
num_classes: int,
6469
target_shape: tuple,
6570
n_scorers: int,
6671
) -> Model:
67-
"""Build the scalar model with direct hyperparameter values.
68-
69-
Args:
70-
learning_rate: Learning rate for the optimizer
71-
q: q parameter for TGCE loss
72-
noise_tolerance: Noise tolerance parameter for TGCE loss
73-
lambda_reg_weight: Regularization weight for TGCE loss
74-
lambda_entropy_weight: Entropy weight for TGCE loss
75-
lambda_sum_weight: Sum weight for TGCE loss
76-
num_classes: Number of classes in the segmentation
77-
target_shape: Target shape of input images
78-
n_scorers: Number of annotators/scorers
79-
80-
Returns:
81-
Compiled Keras model
82-
"""
8372
optimizer = Adam(learning_rate=learning_rate)
8473

85-
loss_fn = TcgeScalar(
74+
loss_fn = TgceScalar(
8675
num_classes=num_classes,
8776
q=q,
8877
noise_tolerance=noise_tolerance,
89-
b=b,
9078
a=a,
79+
b=b,
80+
c=c,
81+
lambda_reg_weight=lambda_reg_weight,
82+
lambda_entropy_weight=lambda_entropy_weight,
83+
lambda_sum_weight=lambda_sum_weight,
9184
name="TGCE",
9285
)
9386

@@ -145,7 +138,7 @@ def build_features_model_from_hparams(
145138
"""
146139
optimizer = Adam(learning_rate=learning_rate)
147140

148-
loss_fn = TcgeFeatures(
141+
loss_fn = TgceFeatures(
149142
num_classes=num_classes,
150143
q=q,
151144
noise_tolerance=noise_tolerance,
@@ -193,7 +186,7 @@ def build_pixel_model_from_hparams(
193186
) -> Model:
194187
optimizer = Adam(learning_rate=learning_rate)
195188

196-
loss_fn = TcgePixel(
189+
loss_fn = TgcePixel(
197190
num_classes=num_classes,
198191
q=q,
199192
noise_tolerance=noise_tolerance,

0 commit comments

Comments (0)