1414# limitations under the License.
1515#
1616import logging
17- from typing import List , Optional , Protocol , Union , runtime_checkable
17+ import warnings
18+ from typing import List , Optional , Protocol , Tuple , Union , runtime_checkable
1819
1920import tensorflow as tf
2021from tensorflow .keras .layers import Layer
@@ -77,6 +78,20 @@ class ContrastiveOutput(ModelOutput):
7778 store_negative_ids: bool, optional
7879 Whether to store negative ids for post-processing
7980 by default False
81+ logq_sampling_correction: bool, optional
82+ The LogQ correction is a standard technique for
83+ sampled softmax and popularity-biased sampling.
84+ It subtracts from the logits the
85+ log expected count/prob of the positive and
86+ negative samples in order to not overpenalize the
87+ popular items for being sampled more often as negatives.
88+ It can be enabled if a single negative sampler is provided
89+ and if the sampler provides the
90+ sampling probabilities (i.e. implements with_sampling_probs()).
91+ Another alternative for performing logQ correction is using
92+ ContrastiveOutput(..., post=PopularityLogitsCorrection(item_frequencies)),
93+ where you need to provide the items frequency probability distribution (prior).
94+ Default is False.
8095
8196 References:
8297 ----------
@@ -132,6 +147,7 @@ def __init__(
132147 query_name : str = "query" ,
133148 candidate_name : str = "candidate" ,
134149 store_negative_ids : bool = False ,
150+ logq_sampling_correction : Optional [bool ] = False ,
135151 ** kwargs ,
136152 ):
137153 self .col_schema = None
@@ -168,6 +184,7 @@ def __init__(
168184 self .query_name = query_name
169185 self .candidate_name = candidate_name
170186 self .store_negative_ids = store_negative_ids
187+ self .logq_sampling_correction = logq_sampling_correction
171188
172189 self .target_name = kwargs .pop ("target" , target_name )
173190 super ().__init__ (
@@ -223,7 +240,9 @@ def call_contrastive(self, inputs, features, targets, training=False, testing=Fa
223240 positive = Candidate (id = positive_id , metadata = {** features }).with_embedding (
224241 positive_embedding
225242 )
226- negative = self .sample_negatives (positive , features , training = training , testing = testing )
243+ negative , positive = self .sample_negatives (
244+ positive , features , training = training , testing = testing
245+ )
227246 if self .has_candidate_weights and (
228247 positive .id .shape != negative .id .shape or positive != negative
229248 ):
@@ -264,6 +283,18 @@ def outputs(
264283 tf .multiply (query_embedding , positive .embedding ), keepdims = True , axis = - 1
265284 )
266285
286+ if self .logq_sampling_correction :
287+ if positive .sampling_prob is None or negative .sampling_prob is None :
288+ warnings .warn (
289+ "The logQ sampling correction is enabled, but sampling probs were not found "
290+ "for both positive and negative candidates" ,
291+ RuntimeWarning ,
292+ )
293+
294+ epsilon = 1e-16
295+ positive_scores -= tf .math .log (positive .sampling_prob + epsilon )
296+ negative_scores -= tf .math .log (tf .transpose (negative .sampling_prob + epsilon ))
297+
267298 if self .downscore_false_negatives :
268299 negative_scores , _ = tf_utils .rescore_false_negatives (
269300 positive .id , negative .id , negative_scores , self .false_negative_score
@@ -295,7 +326,7 @@ def sample_negatives(
295326 features : TabularData ,
296327 training = False ,
297328 testing = False ,
298- ) -> Candidate :
329+ ) -> Tuple [ Candidate , Candidate ] :
299330 """Method to sample negatives from `self.negative_samplers`
300331
301332 Parameters
@@ -311,16 +342,28 @@ def sample_negatives(
311342
312343 Returns
313344 -------
314- Items
315- Class containing the sampled negative ids
345+ Tuple[Candidate, Candidate]
346+ Tuple of candidates with sampled negative ids and the provided positive ids
347+ added with the sampling probability
316348 """
317349 sampling_kwargs = {"training" : training , "testing" : testing , "features" : features }
318350 candidates : List [Candidate ] = []
351+
352+ if self .logq_sampling_correction and len (self .negative_samplers ) > 1 :
353+ raise ValueError (
354+ "It is only possible to apply logQ sampling correction "
355+ "(logq_sampling_correction=True) when only one negative sampler is provided."
356+ )
357+
319358 for sampler in self .negative_samplers :
320- sampled : Candidate = tf_utils .call_layer (sampler , positive , ** sampling_kwargs )
359+ neg_samples : Candidate = tf_utils .call_layer (sampler , positive , ** sampling_kwargs )
360+
361+ # Adds to the positive and negative candidates their sampling probs from the sampler
362+ positive = sampler .with_sampling_probs (positive )
363+ neg_samples = sampler .with_sampling_probs (neg_samples )
321364
322- if sampled .id is not None :
323- candidates .append (sampled )
365+ if neg_samples .id is not None :
366+ candidates .append (neg_samples )
324367 else :
325368 LOG .warn (
326369 f"The sampler { type (sampler ).__name__ } returned no samples for this batch."
@@ -336,7 +379,7 @@ def sample_negatives(
336379 for neg in candidates [1 :]:
337380 negatives += neg
338381
339- return negatives
382+ return negatives , positive
340383
341384 def embedding_lookup (self , ids : tf .Tensor ):
342385 return self .to_call .embedding_lookup (tf .squeeze (ids ))
0 commit comments