Skip to content

Commit a8df9e4

Browse files
authored
fix logits temperature (#868)
1 parent a507022 commit a8df9e4

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

merlin/models/tf/outputs/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def __call__(self, inputs, *args, **kwargs):
140140

141141
if getattr(self, "logits_scaler", None):
142142
if isinstance(outputs, tf.Tensor):
143-
outputs = Prediction(outputs)
143+
targets = kwargs.pop("targets", None)
144+
outputs = Prediction(outputs, targets)
144145
outputs = self.logits_scaler(outputs)
145146

146147
return outputs

tests/unit/tf/models/test_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,3 +860,29 @@ def _compute_model_metrics(
860860
assert not np.isclose(loss1, loss_masked)
861861
for k in metrics1:
862862
assert not np.isclose(metrics1[k], metrics_masked[k])
863+
864+
865+
def test_categorical_prediction_with_temperature(sequence_testing_data: Dataset):
866+
train = sequence_testing_data
867+
train.schema = train.schema.select_by_name(["item_id_seq", "user_country"])
868+
schema_model = train.schema.select_by_name(["item_id_seq"])
869+
inputs = mm.InputBlockV2(
870+
schema_model,
871+
embeddings=mm.Embeddings(
872+
schema_model,
873+
),
874+
)
875+
model = mm.Model(
876+
inputs,
877+
mm.MLPBlock([32]),
878+
mm.CategoricalOutput(
879+
to_call=train.schema.select_by_name(["user_country"]), logits_temperature=0.2
880+
),
881+
)
882+
883+
loader = mm.Loader(
884+
train, batch_size=1024, transform=mm.ToTarget(train.schema, "user_country", one_hot=True)
885+
)
886+
887+
model.compile(run_eagerly=False, optimizer="adam")
888+
model.fit(loader, batch_size=1024, epochs=1)

0 commit comments

Comments
 (0)