Commit 2a20547

gabrielspmoreira, marcromeyn, edknv authored
Refactory/fix of sampled softmax on PopularityBasedSamplerV2 / ContrastiveOutput / Candidate (#1051)
Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: edknv <[email protected]>
1 parent 78f2732 commit 2a20547

File tree

6 files changed: +157 -21 lines changed

merlin/models/tf/models/retrieval.py

Lines changed: 1 addition & 0 deletions
@@ -579,6 +579,7 @@ def YoutubeDNNRetrievalModelV2(
         negative_samplers=PopularityBasedSamplerV2(
             max_num_samples=num_sampled, max_id=num_classes - 1, min_id=min_sampled_id
         ),
+        logq_sampling_correction=True,
     )

     return RetrievalModelV2(query=query, output=outputs)
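With this flag enabled by default, the YouTube DNN retrieval model now trains with a logQ-corrected sampled softmax out of the box. A minimal usage sketch (the dataset name and hyperparameter values are assumptions for illustration, not part of this commit):

import merlin.models.tf as mm

# `train` is assumed to be a merlin.io.Dataset with item-id and user features
model = mm.YoutubeDNNRetrievalModelV2(schema=train.schema, num_sampled=100)
model.compile(optimizer="adam")
model.fit(train, batch_size=1024, epochs=1)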

merlin/models/tf/outputs/contrastive.py

Lines changed: 52 additions & 9 deletions
@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 import logging
-from typing import List, Optional, Protocol, Union, runtime_checkable
+import warnings
+from typing import List, Optional, Protocol, Tuple, Union, runtime_checkable

 import tensorflow as tf
 from tensorflow.keras.layers import Layer
@@ -77,6 +78,20 @@ class ContrastiveOutput(ModelOutput):
     store_negative_ids: bool, optional
         Whether to store negative ids for post-processing
         by default False
+    logq_sampling_correction: bool, optional
+        The logQ correction is a standard technique for
+        sampled softmax and popularity-biased sampling.
+        It subtracts the log of the expected count/prob of the
+        positive and negative samples from the logits, so that
+        popular items are not over-penalized for being sampled
+        more often as negatives.
+        It can be enabled if a single negative sampler is provided
+        and if that sampler provides the sampling probabilities
+        (i.e. implements with_sampling_probs()).
+        An alternative way of performing the logQ correction is
+        ContrastiveOutput(..., post=PopularityLogitsCorrection(item_frequencies)),
+        where you need to provide the item frequency probability distribution (prior).
+        Default is False.

     References:
     ----------
@@ -132,6 +147,7 @@ def __init__(
         query_name: str = "query",
         candidate_name: str = "candidate",
         store_negative_ids: bool = False,
+        logq_sampling_correction: Optional[bool] = False,
         **kwargs,
     ):
         self.col_schema = None
@@ -168,6 +184,7 @@ def __init__(
         self.query_name = query_name
         self.candidate_name = candidate_name
         self.store_negative_ids = store_negative_ids
+        self.logq_sampling_correction = logq_sampling_correction

         self.target_name = kwargs.pop("target", target_name)
         super().__init__(
@@ -223,7 +240,9 @@ def call_contrastive(self, inputs, features, targets, training=False, testing=Fa
         positive = Candidate(id=positive_id, metadata={**features}).with_embedding(
             positive_embedding
         )
-        negative = self.sample_negatives(positive, features, training=training, testing=testing)
+        negative, positive = self.sample_negatives(
+            positive, features, training=training, testing=testing
+        )
         if self.has_candidate_weights and (
             positive.id.shape != negative.id.shape or positive != negative
         ):
@@ -264,6 +283,18 @@ def outputs(
             tf.multiply(query_embedding, positive.embedding), keepdims=True, axis=-1
         )

+        if self.logq_sampling_correction:
+            if positive.sampling_prob is None or negative.sampling_prob is None:
+                warnings.warn(
+                    "The logQ sampling correction is enabled, but sampling probs were not found "
+                    "for both positive and negative candidates",
+                    RuntimeWarning,
+                )
+
+            epsilon = 1e-16
+            positive_scores -= tf.math.log(positive.sampling_prob + epsilon)
+            negative_scores -= tf.math.log(tf.transpose(negative.sampling_prob + epsilon))
+
         if self.downscore_false_negatives:
             negative_scores, _ = tf_utils.rescore_false_negatives(
                 positive.id, negative.id, negative_scores, self.false_negative_score
@@ -295,7 +326,7 @@ def sample_negatives(
         features: TabularData,
         training=False,
         testing=False,
-    ) -> Candidate:
+    ) -> Tuple[Candidate, Candidate]:
         """Method to sample negatives from `self.negative_samplers`

         Parameters
@@ -311,16 +342,28 @@

         Returns
         -------
-        Items
-            Class containing the sampled negative ids
+        Tuple[Candidate, Candidate]
+            Tuple with the sampled negative candidates and the provided
+            positive candidates, both with the sampling probability set
         """
         sampling_kwargs = {"training": training, "testing": testing, "features": features}
         candidates: List[Candidate] = []
+
+        if self.logq_sampling_correction and len(self.negative_samplers) > 1:
+            raise ValueError(
+                "It is only possible to apply the logQ sampling correction "
+                "(logq_sampling_correction=True) when only one negative sampler is provided."
+            )
+
         for sampler in self.negative_samplers:
-            sampled: Candidate = tf_utils.call_layer(sampler, positive, **sampling_kwargs)
+            neg_samples: Candidate = tf_utils.call_layer(sampler, positive, **sampling_kwargs)
+
+            # Adds to the positive and negative candidates their sampling probs from the sampler
+            positive = sampler.with_sampling_probs(positive)
+            neg_samples = sampler.with_sampling_probs(neg_samples)

-            if sampled.id is not None:
-                candidates.append(sampled)
+            if neg_samples.id is not None:
+                candidates.append(neg_samples)
             else:
                 LOG.warn(
                     f"The sampler {type(sampler).__name__} returned no samples for this batch."
@@ -336,7 +379,7 @@
         for neg in candidates[1:]:
             negatives += neg

-        return negatives
+        return negatives, positive

     def embedding_lookup(self, ids: tf.Tensor):
         return self.to_call.embedding_lookup(tf.squeeze(ids))
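To make the new correction branch concrete, here is a minimal standalone sketch of the logits adjustment that outputs() now applies (shapes and values are invented for illustration, not taken from this commit):

import tensorflow as tf

batch_size, num_negatives = 4, 20
positive_scores = tf.random.uniform((batch_size, 1))        # dot(query, positive item)
negative_scores = tf.random.uniform((batch_size, num_negatives))
positive_probs = tf.random.uniform((batch_size, 1))         # sampler prob of each positive
negative_probs = tf.random.uniform((num_negatives, 1))      # sampler prob of each negative

epsilon = 1e-16  # guards against log(0)
# logQ correction: subtract the log sampling probability from the logits so that
# popular items, which are sampled as negatives more often, are not over-penalized.
positive_scores -= tf.math.log(positive_probs + epsilon)
negative_scores -= tf.math.log(tf.transpose(negative_probs + epsilon))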

merlin/models/tf/outputs/sampling/base.py

Lines changed: 14 additions & 3 deletions
@@ -30,13 +30,16 @@ class Candidate(NamedTuple):
     ----------
     id : tf.Tensor
         The tensor of item ids
+    sampling_prob : tf.Tensor
+        Useful for logQ correction, based on the sampling distribution
     metadata:
         dictionary of tensors containing meta information
         about items such as item embeddings and item category
     """

     id: tf.Tensor
     metadata: Dict[str, tf.Tensor]
+    sampling_prob: Optional[tf.Tensor] = None

     @property
     def embedding(self) -> tf.Tensor:
@@ -51,6 +54,9 @@ def with_embedding(self, embedding: tf.Tensor) -> "Candidate":

         return self

+    def with_sampling_prob(self, sampling_prob: tf.Tensor) -> "Candidate":
+        return Candidate(id=self.id, metadata=self.metadata, sampling_prob=sampling_prob)
+
     def __add__(self, other):
         metadata = {}
         for key in self.metadata:
@@ -68,12 +74,12 @@ def shape(self) -> "Candidate":
     def __repr__(self):
         metadata = {key: str(val) for key, val in self.metadata.items()}

-        return f"Candidate({self.id}, {metadata})"
+        return f"Candidate({self.id}, {self.sampling_prob}, {metadata})"

     def __str__(self):
         metadata = {key: str(val) for key, val in self.metadata.items()}

-        return f"Candidate({self.id}, {metadata})"
+        return f"Candidate({self.id}, {self.sampling_prob}, {metadata})"

     def __eq__(self, other) -> bool:
         if self.id.shape != other.id.shape:
@@ -84,15 +90,17 @@ def __eq__(self, other) -> bool:
     def get_config(self):
         return {
             "id": self.id,
+            "sampling_prob": self.sampling_prob,
             "metadata": self.metadata,
         }

     @classmethod
     def from_config(cls, config):
         ids = config["config"]["id"]
+        sampling_prob = config["config"]["sampling_prob"]
         metadata = config["config"]["metadata"]

-        return cls(ids, metadata)
+        return cls(ids, metadata, sampling_prob)


 negative_sampling_registry: Registry = Registry.class_registry("tf.negative_sampling")
@@ -139,6 +147,9 @@ def add(self, items: Candidate):
     def sample(self) -> Candidate:
         raise NotImplementedError()

+    def with_sampling_probs(self, items: Candidate) -> Candidate:
+        return items
+
     @property
     def max_num_samples(self) -> int:
         return self._max_num_samples
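A small illustration of the extended named tuple (ids and probabilities invented). Because NamedTuple fields are immutable, with_sampling_prob() returns a new Candidate rather than mutating in place:

import tensorflow as tf
from merlin.models.tf.outputs.sampling.base import Candidate

ids = tf.constant([[3], [7]], dtype=tf.int64)
cand = Candidate(id=ids, metadata={})    # sampling_prob defaults to None
probs = tf.constant([[0.21], [0.09]])
cand = cand.with_sampling_prob(probs)    # copy with sampling_prob set
print(cand.sampling_prob)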

merlin/models/tf/outputs/sampling/popularity.py

Lines changed: 85 additions & 6 deletions
@@ -23,9 +23,25 @@
 @tf.keras.utils.register_keras_serializable(package="merlin.models")
 class PopularityBasedSamplerV2(CandidateSampler):
     """
-    Provides a popularity-based negative sampling for the softmax layer
+    Provides popularity-based negative sampling for sampled softmax [1]_ [2]_
     to ensure training efficiency when the catalog of items is very large.
-    The capacity of the queue is fixed and is equal to the catalog size.
+    Items are sampled from the whole catalog. The sampler can also provide
+    the sampling probabilities of both positive and negative candidates,
+    which are required by the logQ correction of sampled softmax.
+    This class does not require the actual frequency of items. It assumes
+    that item ids are sorted by frequency and follow a long-tail distribution,
+    and uses tf.random.log_uniform_candidate_sampler() to sample the candidate ids.
+
+    References
+    ----------
+    .. [1] Yoshua Bengio and Jean-Sébastien Sénécal. 2003. Quick Training of Probabilistic
+       Neural Nets by Importance Sampling. In Proceedings of the Conference on Artificial
+       Intelligence and Statistics (AISTATS).
+
+    .. [2] Y. Bengio and J. S. Senecal. 2008. Adaptive Importance Sampling to Accelerate
+       Training of a Neural Probabilistic Language Model. Trans. Neur. Netw. 19, 4 (April
+       2008), 713–722. https://doi.org/10.1109/TNN.2007.912312
+

     Parameters
     ----------
@@ -38,6 +54,8 @@ class PopularityBasedSamplerV2(CandidateSampler):
         Defaults to 0.
     max_num_samples: int
         The number of unique negatives to sample at each batch.
+    unique: bool, optional
+        Whether to return unique candidate ids or allow repeated ones. Defaults to True.
     seed: int
         Fix the random values returned by the sampler to ensure reproducibility
         Defaults to None
@@ -48,13 +66,17 @@ def __init__(
         max_id: int,
         min_id: int = 0,
         max_num_samples: int = 10,
+        unique: Optional[bool] = True,
         seed: Optional[int] = None,
         **kwargs,
     ):
         super().__init__(max_num_samples=max_num_samples, **kwargs)
         self.max_id = max_id
         self.min_id = min_id
         self.seed = seed
+        self.unique = unique
+
+        self.sampling_dist = self.get_sampling_distribution()

         assert (
             self.max_num_samples <= self.max_id
@@ -91,22 +113,79 @@ def sample(self) -> Candidate:
         Items
             The negative items ids
         """
-        sampled_ids, _, _ = tf.random.log_uniform_candidate_sampler(
+        (
+            sampled_ids,
+            _,
+            _,
+        ) = tf.random.log_uniform_candidate_sampler(
+            # This is just a placeholder for true_classes.
+            # The positive ids should be provided here if we wanted
+            # the expected count probs returned.
+            # We rather make use of the CandidateSampler.with_sampling_probs()
+            # method to get the sampling probs of positives and negatives.
             true_classes=tf.ones((1, 1), dtype=tf.int64),
             num_true=1,
             num_sampled=self.max_num_samples,
-            unique=True,
+            unique=self.unique,
             range_max=self.max_id - self.min_id,
             seed=self.seed,
         )
-
         # Shifting the sampled ids to ignore the first ids (usually reserved for nulls, OOV)
         sampled_ids += self.min_id
-
         sampled_ids = tf.expand_dims(sampled_ids, -1)

+        sampled_ids = tf.stop_gradient(sampled_ids)
+
         return Candidate(id=sampled_ids, metadata={})

+    def get_sampling_distribution(self) -> tf.Tensor:
+        """Returns the approximated distribution used to sample items
+        by tf.random.log_uniform_candidate_sampler()
+
+        Returns
+        -------
+        tf.Tensor
+            Probability of each item to be sampled
+        """
+        log_indices = tf.math.log(tf.range(1.0, self.max_id - self.min_id + 2.0, 1.0))
+        sampling_probs = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
+
+        if self.unique:
+            # Below is a more numerically stable implementation of the probability
+            # of sampling an item at least once (suitable for sampling unique items):
+            # P(item is sampled at least once) = 1 - P(item is not sampled)^num_trials
+            # where P(item is not sampled) = 1-p and p is the probability
+            # of the item being sampled in a single trial
+            sampling_probs = -tf.math.expm1(self.max_num_samples * tf.math.log1p(-sampling_probs))
+
+        # Shifting probs if the first values of the item id mapping table are reserved
+        if self.min_id > 0:
+            sampling_probs = tf.concat(
+                [tf.zeros([self.min_id], dtype=sampling_probs.dtype), sampling_probs], axis=0
+            )
+
+        sampling_probs = tf.stop_gradient(sampling_probs)
+
+        return sampling_probs
+
+    def with_sampling_probs(self, items: Candidate) -> Candidate:
+        """Returns a copy of the Candidate named tuple with
+        the sampling_prob set.
+
+        Parameters
+        ----------
+        items : Candidate
+            Positive or negative candidate items
+
+        Returns
+        -------
+        Candidate
+            Candidate items with the sampling probability set
+        """
+        sampling_probs = tf.gather(self.sampling_dist, items.id)
+        items_with_sampling_prob = items.with_sampling_prob(sampling_probs)
+        return items_with_sampling_prob
+
     def get_config(self):
         config = super().get_config()
         config["max_id"] = self.max_id

merlin/models/tf/transforms/bias.py

Lines changed: 2 additions & 1 deletion
@@ -82,7 +82,8 @@ class PopularityLogitsCorrection(Block):
         where `item_prob = item_freq_count / sum(item_freq_count)` is
         a probability distribution of the item frequency. In a nutshell,
         the logQ correction aims to increase the prediction scores (logits)
-        for infrequent items and decrease the ones for frequent items.
+        for infrequent items and decrease the ones for frequent items, so
+        that frequent items are not over-penalized for being sampled more often.

     References
     ----------
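This block is the alternative logQ path mentioned in the ContrastiveOutput docstring above, for when an item frequency prior is available. A usage sketch (the schema column and frequency values are invented; the constructor call follows the pattern quoted in that docstring):

import tensorflow as tf
import merlin.models.tf as mm
from merlin.models.tf.outputs.sampling.popularity import PopularityBasedSamplerV2
from merlin.models.tf.transforms.bias import PopularityLogitsCorrection

# Hypothetical frequency prior over item ids (index = item id)
item_freq_probs = tf.constant([0.0, 0.5, 0.3, 0.15, 0.05])

output = mm.ContrastiveOutput(
    schema["item_id"],  # assumed: the item-id column of your dataset schema
    negative_samplers=PopularityBasedSamplerV2(max_id=4, max_num_samples=2, min_id=1),
    post=PopularityLogitsCorrection(item_freq_probs),
)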

tests/unit/tf/outputs/test_contrastive.py

Lines changed: 3 additions & 2 deletions
@@ -101,7 +101,7 @@ def test_two_tower_constrastive_with_logq_correction(ecommerce_data: Dataset):


 @pytest.mark.parametrize("run_eagerly", [True, False])
-def test_contrastive_output(ecommerce_data: Dataset, run_eagerly):
+def test_contrastive_output_with_sampled_softmax(ecommerce_data: Dataset, run_eagerly):
     schema = ecommerce_data.schema
     schema["item_category"] = schema["item_category"].with_tags(
         schema["item_category"].tags + "target"
@@ -112,7 +112,8 @@ def test_contrastive_output(ecommerce_data: Dataset, run_eagerly):
         mm.MLPBlock([8]),
         mm.ContrastiveOutput(
             schema["item_category"],
-            negative_samplers=PopularityBasedSamplerV2(max_id=100, max_num_samples=20),
+            negative_samplers=PopularityBasedSamplerV2(max_id=100, max_num_samples=20, min_id=1),
+            logq_sampling_correction=True,
         ),
     )
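To run just this updated test locally, a standard pytest invocation with a keyword filter works:

pytest tests/unit/tf/outputs/test_contrastive.py -k sampled_softmax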
