Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ deel_lip_dev_env
logs
export-wass
site
docs/notebooks/save_img
wandb
135 changes: 135 additions & 0 deletions deel/lip/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,141 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "MulticlassSoftHKR")
class MulticlassSoftHKR(Loss):
def __init__(
self,
alpha=10.0,
min_margin=1.0,
alpha_mean=0.99,
temperature=1.0,
reduction=Reduction.AUTO,
name="MulticlassSoftHKR",
):
"""
The multiclass version of HKR with softmax. This is done by computing
the HKR term over each class and averaging the results.

Note that `y_true` could be either one-hot encoded, +/-1 values.


Args:
alpha (float): regularization factor
min_margin (float): margin to enforce.
alpha_mean (float): geometric mean factor
temperature (float): factor for softmax temperature
(higher value increases the weight of the highest non y_true logits)
reduction: passed to tf.keras.Loss constructor
name (str): passed to tf.keras.Loss constructor

"""
self.alpha = tf.Variable(alpha, dtype=tf.float32)
self.min_margin_v = min_margin
self.alpha_mean = alpha_mean

self.current_mean = tf.Variable(
(self.min_margin_v,),
dtype=tf.float32,
constraint=lambda x: tf.clip_by_value(x, 0.005, 1000),
name="current_mean",
)

self.temperature = temperature * self.min_margin_v
if alpha == np.inf: # alpha = inf => hinge only
self.fct = self.multiclass_hinge_soft
else:
self.fct = self.hkr

super(MulticlassSoftHKR, self).__init__(reduction=reduction, name=name)

@tf.function
def _update_mean(self, y_pred):
current_global_mean = tf.cast(
tf.reduce_mean(tf.abs(y_pred)), self.current_mean.dtype
)
current_global_mean = (
self.alpha_mean * self.current_mean
+ (1 - self.alpha_mean) * current_global_mean
)
self.current_mean.assign(current_global_mean)
total_mean = current_global_mean
total_mean = tf.clip_by_value(total_mean, self.min_margin_v, 20000)
return total_mean

def computeTemperatureSoftMax(self, y_true, y_pred):
total_mean = self._update_mean(y_pred)
current_temperature = tf.cast(
tf.stop_gradient(
tf.clip_by_value(self.temperature / total_mean, 0.005, 250)
),
y_pred.dtype,
)

opposite_values = tf.where(
y_true > 0, -y_pred.dtype.max, current_temperature * y_pred
)
F_soft_KR = tf.nn.softmax(opposite_values)
F_soft_KR = tf.where(y_true > 0, tf.cast(1.0, F_soft_KR.dtype), F_soft_KR)
return F_soft_KR

def signed_y_pred(self, y_true, y_pred):
"""Return for each item sign(y_true)*y_pred."""
sign_y_true = tf.where(y_true > 0, 1, -1) # switch to +/-1
sign_y_true = tf.cast(sign_y_true, y_pred.dtype)
return y_pred * sign_y_true

def multiclass_hinge_preproc(self, signed_y_pred, min_margin):
"""From multiclass_hinge(y_true, y_pred, min_margin)
simplified to use precalculated signed_y_pred"""
# compute the elementwise hinge term
hinge = tf.nn.relu(min_margin / 2.0 - signed_y_pred)
return hinge

@tf.function
def multiclass_hinge_soft(self, y_true, y_pred):
F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
signed_y_pred = self.signed_y_pred(y_true, y_pred)
hinge = self.multiclass_hinge_preproc(signed_y_pred, self.min_margin_v)
b = hinge * F_soft_KR
return b

# @tf.function
def hkr(self, y_true, y_pred):
F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
signed_y_pred = self.signed_y_pred(y_true, y_pred)
kr = -signed_y_pred
a = kr * F_soft_KR
a = tf.reduce_sum(a, axis=-1)

hinge = self.multiclass_hinge_preproc(signed_y_pred, self.min_margin_v)

b = hinge * F_soft_KR
b = tf.reduce_sum(b, axis=-1)

# tf.print(self.alpha)
beta = 1.0 / self.alpha
# Hinge with coef 1 and hkr with lower coef a/self.alpha + b
return beta * a + b

def call(self, y_true, y_pred):
if not (isinstance(y_pred, tf.Tensor)): # required for dtype.max
y_pred = tf.convert_to_tensor(y_pred, dtype=y_pred.dtype)
if not (isinstance(y_true, tf.Tensor)):
y_true = tf.convert_to_tensor(y_true, dtype=y_pred.dtype)
return self.fct(y_true, y_pred)

def get_config(self):
config = {
"alpha": self.alpha.numpy(),
"min_margin": self.min_margin_v,
"alpha_mean": self.alpha_mean,
"temperature": self.temperature
/ self.min_margin_v, # consistency with the __init__
}
base_config = super(MulticlassSoftHKR, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "MultiMargin")
class MultiMargin(Loss):
def __init__(self, min_margin=1.0, reduction=Reduction.AUTO, name="MultiMargin"):
Expand Down
Loading