Skip to content

Commit 5f6b99a

Browse files
Update llm_utils.py
Add more numerical stability to gating strategy.
1 parent 21fbc1f commit 5f6b99a

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 17 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -696,37 +696,38 @@ def get_config(self):
696696

697697

698698
# Gating merge layer
699-
700699
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='GatedMergeLayer')
class GatedMergeLayer(tf.keras.layers.Layer):
    """Merge two input streams with a learned, feature-wise sigmoid gate.

    Given ``inputs = [input_1, input_2]``, a Dense layer computes gate
    values ``g`` in (0, 1) from ``input_1``, and the output is the convex
    blend ``g * input_1 + (1 - g) * input_2``.
    """

    def __init__(self, d_model, **kwargs):
        """
        Args:
            d_model (int): Feature dimension of both input streams.
        """
        super().__init__(**kwargs)
        self.d_model = d_model
        # Zero bias => sigmoid(0) = 0.5, so the untrained gate starts as an
        # even blend of the two streams (pass-through-like behavior).
        self.gate_dense = tf.keras.layers.Dense(
            d_model,
            activation='sigmoid',
            bias_initializer=tf.keras.initializers.Constant(0.0),
        )

    def call(self, inputs):
        stream_a, stream_b = inputs
        gate = self.gate_dense(stream_a)
        # Keep the gate strictly inside (0, 1): float32 sigmoid can saturate
        # to exactly 0.0 or 1.0 for large-magnitude logits, which would shut
        # one stream (and its gradient path) off completely.
        gate = tf.clip_by_value(gate, 1e-7, 1 - 1e-7)
        # Feature-wise convex combination of the two streams.
        return tf.add(
            tf.multiply(gate, stream_a),
            tf.multiply(1.0 - gate, stream_b),
        )

    def get_config(self):
        # Serialize d_model so the layer round-trips through
        # tf.keras model save/load.
        config = super().get_config()
        config.update({"d_model": self.d_model})
        return config
731732

732733

@@ -1472,7 +1473,6 @@ def call(self, inputs, training=False):
14721473
attn_output = self.dropout1(attn_output, training=training)
14731474

14741475
# 4. *** CHANGE: GATE the original input and the attention output using the standard layer ***
1475-
# This replaces the old manual gating logic.
14761476
merged_output = self.gate([inputs, attn_output])
14771477

14781478
# --- Feed-Forward Sub-layer with Pre-LN and Residual ---

0 commit comments

Comments (0)