Skip to content

Commit 8ac9090

Browse files
authored
fix contrastive output (#800)
1 parent abab6e1 commit 8ac9090

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

merlin/models/tf/losses/pairwise.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
7272
tf.Tensor
7373
Loss per example
7474
"""
75+
tf.assert_equal(tf.rank(y_true), 2, f"Targets must be 2-D tensor (got {y_true.shape})")
76+
77+
tf.assert_equal(tf.rank(y_pred), 2, f"Predictions must be 2-D tensor (got {y_pred.shape})")
78+
7579
(
7680
positives_scores,
7781
negatives_scores,

merlin/models/tf/outputs/contrastive.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def outputs(
226226
# To ensure that the output is always fp32, avoiding numerical
227227
# instabilities with mixed_float16 policy
228228
outputs = tf.cast(outputs, tf.float32)
229-
outputs = tf.squeeze(outputs)
230229

231230
targets = tf.concat(
232231
[

tests/unit/tf/outputs/test_contrastive.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,21 @@ def test_contrastive_only_positive_when_not_training(ecommerce_data: Dataset):
190190
)
191191

192192

193+
@pytest.mark.parametrize("run_eagerly", [True, False])
194+
def test_contrastive_output_with_pairwise_loss(ecommerce_data: Dataset, run_eagerly):
195+
model = mm.RetrievalModelV2(
196+
query=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.USER), mm.MLPBlock([2])),
197+
candidate=mm.Encoder(ecommerce_data.schema.select_by_tag(Tags.ITEM), mm.MLPBlock([2])),
198+
output=mm.ContrastiveOutput(
199+
ecommerce_data.schema.select_by_tag(Tags.ITEM_ID),
200+
negative_samplers="in-batch",
201+
candidate_name="item",
202+
),
203+
)
204+
model.compile(run_eagerly=run_eagerly, loss="bpr-max")
205+
_ = model.fit(ecommerce_data, batch_size=50, epochs=1)
206+
207+
193208
def _retrieval_inputs_(batch_size):
194209
users_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32)
195210
items_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32)

0 commit comments

Comments
 (0)