Skip to content

Commit a1d0be2

Browse files
Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) (#1198)
* Making retrieval model to_top_k_model(), candidate_embeddings() and batch_predict() support Loader with transforms for pre-trained embeddings in item tower * Fixing test error and ensuring all batch_predict() with the new API support Loader with transforms (which include pre-trained embeddings) * Fixing retrieval example, which was using wrong schema to export query and item embeddings * Added missing importorskip on torch and pytorch_lightning for torch integration tests * Skipping a test if nvtabular is available
1 parent 52c89a4 commit a1d0be2

File tree

6 files changed

+160
-27
lines changed

6 files changed

+160
-27
lines changed

examples/05-Retrieval-Model.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,7 +1616,8 @@
16161616
}
16171617
],
16181618
"source": [
1619-
"queries = model.query_embeddings(Dataset(user_features, schema=schema), batch_size=1024, index=Tags.USER_ID)\n",
1619+
"queries = model.query_embeddings(Dataset(user_features, schema=schema.select_by_tag(Tags.USER)), \n",
1620+
" batch_size=1024, index=Tags.USER_ID)\n",
16201621
"query_embs_df = queries.compute(scheduler=\"synchronous\").reset_index()"
16211622
]
16221623
},
@@ -1996,7 +1997,8 @@
19961997
}
19971998
],
19981999
"source": [
1999-
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema), batch_size=1024, index=Tags.ITEM_ID)"
2000+
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema.select_by_tag(Tags.ITEM)), \n",
2001+
" batch_size=1024, index=Tags.ITEM_ID)"
20002002
]
20012003
},
20022004
{
@@ -2460,7 +2462,7 @@
24602462
"name": "python",
24612463
"nbconvert_exporter": "python",
24622464
"pygments_lexer": "ipython3",
2463-
"version": "3.10.8"
2465+
"version": "3.8.10"
24642466
},
24652467
"merlin": {
24662468
"containers": [

merlin/models/tf/core/encoder.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from merlin.models.tf.core.prediction import TopKPrediction
2828
from merlin.models.tf.inputs.base import InputBlockV2
2929
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
30+
from merlin.models.tf.loader import Loader
3031
from merlin.models.tf.models.base import BaseModel, get_output_schema
3132
from merlin.models.tf.outputs.topk import TopKOutput
3233
from merlin.models.tf.transforms.features import PrepareFeatures
@@ -84,7 +85,7 @@ def __init__(
8485

8586
def encode(
8687
self,
87-
dataset: merlin.io.Dataset,
88+
dataset: Union[merlin.io.Dataset, Loader],
8889
index: Union[str, ColumnSchema, Schema, Tags],
8990
batch_size: int,
9091
**kwargs,
@@ -93,7 +94,7 @@ def encode(
9394
9495
Parameters
9596
----------
96-
dataset: merlin.io.Dataset
97+
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
9798
The dataset to encode.
9899
index: Union[str, ColumnSchema, Schema, Tags]
99100
The index to use for encoding.
@@ -127,7 +128,7 @@ def encode(
127128

128129
def batch_predict(
129130
self,
130-
dataset: merlin.io.Dataset,
131+
dataset: Union[merlin.io.Dataset, Loader],
131132
batch_size: int,
132133
output_schema: Optional[Schema] = None,
133134
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
@@ -137,8 +138,8 @@ def batch_predict(
137138
138139
Parameters
139140
----------
140-
dataset: merlin.io.Dataset
141-
Dataset to predict on.
141+
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
142+
Dataset or Loader to predict on.
142143
batch_size: int
143144
Batch size to use for prediction.
144145
@@ -161,18 +162,35 @@ def batch_predict(
161162
raise ValueError("Only one column can be used as index")
162163
index = index.first.name
163164

165+
dataset_schema = None
164166
if hasattr(dataset, "schema"):
165-
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
167+
dataset_schema = dataset.schema
168+
data_output_schema = dataset_schema
169+
if isinstance(dataset, Loader):
170+
data_output_schema = dataset.output_schema
171+
if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
166172
raise ValueError(
167173
f"Model schema {self.schema.column_names} does not match dataset schema"
168-
+ f" {dataset.schema.column_names}"
174+
+ f" {data_output_schema.column_names}"
169175
)
170176

177+
loader_transforms = None
178+
if isinstance(dataset, Loader):
179+
loader_transforms = dataset.transforms
180+
batch_size = dataset.batch_size
181+
dataset = dataset.dataset
182+
171183
# Check if merlin-dataset is passed
172184
if hasattr(dataset, "to_ddf"):
173185
dataset = dataset.to_ddf()
174186

175-
model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
187+
model_encode = TFModelEncode(
188+
self,
189+
batch_size=batch_size,
190+
loader_transforms=loader_transforms,
191+
schema=dataset_schema,
192+
**kwargs,
193+
)
176194

177195
encode_kwargs = {}
178196
if output_schema:
@@ -583,7 +601,7 @@ def encode_candidates(
583601

584602
def batch_predict(
585603
self,
586-
dataset: merlin.io.Dataset,
604+
dataset: Union[merlin.io.Dataset, Loader],
587605
batch_size: int,
588606
output_schema: Optional[Schema] = None,
589607
**kwargs,
@@ -592,8 +610,8 @@ def batch_predict(
592610
593611
Parameters
594612
----------
595-
dataset : merlin.io.Dataset
596-
Raw queries features dataset
613+
dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
614+
Raw queries features dataset or Loader
597615
batch_size : int
598616
The number of queries to process at each prediction step
599617
output_schema: Schema, optional
@@ -606,15 +624,24 @@ def batch_predict(
606624
"""
607625
from merlin.models.tf.utils.batch_utils import TFModelEncode
608626

627+
loader_transforms = None
628+
if isinstance(dataset, Loader):
629+
loader_transforms = dataset.transforms
630+
batch_size = dataset.batch_size
631+
dataset = dataset.dataset
632+
633+
dataset_schema = dataset.schema
634+
dataset = dataset.to_ddf()
635+
609636
model_encode = TFModelEncode(
610637
model=self,
611638
batch_size=batch_size,
639+
loader_transforms=loader_transforms,
640+
schema=dataset_schema,
612641
output_names=TopKPrediction.output_names(self.k),
613642
**kwargs,
614643
)
615644

616-
dataset = dataset.to_ddf()
617-
618645
encode_kwargs = {}
619646
if output_schema:
620647
encode_kwargs["filter_input_columns"] = output_schema.column_names

merlin/models/tf/models/base.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,7 +1553,7 @@ def predict(
15531553
return out
15541554

15551555
def batch_predict(
1556-
self, dataset: merlin.io.Dataset, batch_size: int, **kwargs
1556+
self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs
15571557
) -> merlin.io.Dataset:
15581558
"""Batched prediction using the Dask.
15591559
Parameters
@@ -1565,20 +1565,38 @@ def batch_predict(
15651565
Returns merlin.io.Dataset
15661566
-------
15671567
"""
1568+
dataset_schema = None
15681569
if hasattr(dataset, "schema"):
1569-
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
1570+
dataset_schema = dataset.schema
1571+
data_output_schema = dataset_schema
1572+
if isinstance(dataset, Loader):
1573+
data_output_schema = dataset.output_schema
1574+
1575+
if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
15701576
raise ValueError(
15711577
f"Model schema {self.schema.column_names} does not match dataset schema"
1572-
+ f" {dataset.schema.column_names}"
1578+
+ f" {data_output_schema.column_names}"
15731579
)
15741580

1581+
loader_transforms = None
1582+
if isinstance(dataset, Loader):
1583+
loader_transforms = dataset.transforms
1584+
batch_size = dataset.batch_size
1585+
dataset = dataset.dataset
1586+
15751587
# Check if merlin-dataset is passed
15761588
if hasattr(dataset, "to_ddf"):
15771589
dataset = dataset.to_ddf()
15781590

15791591
from merlin.models.tf.utils.batch_utils import TFModelEncode
15801592

1581-
model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
1593+
model_encode = TFModelEncode(
1594+
self,
1595+
batch_size=batch_size,
1596+
loader_transforms=loader_transforms,
1597+
schema=dataset_schema,
1598+
**kwargs,
1599+
)
15821600

15831601
# Processing a sample of the dataset with the model encoder
15841602
# to get the output dataframe dtypes
@@ -2510,20 +2528,20 @@ def query_embeddings(
25102528

25112529
def candidate_embeddings(
25122530
self,
2513-
dataset: Optional[merlin.io.Dataset] = None,
2531+
data: Optional[Union[merlin.io.Dataset, Loader]] = None,
25142532
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
25152533
**kwargs,
25162534
) -> merlin.io.Dataset:
25172535
if self.has_candidate_encoder:
25182536
candidate = self.candidate_encoder
25192537

2520-
if dataset is not None and hasattr(candidate, "encode"):
2521-
return candidate.encode(dataset, index=index, **kwargs)
2538+
if data is not None and hasattr(candidate, "encode"):
2539+
return candidate.encode(data, index=index, **kwargs)
25222540

25232541
if hasattr(candidate, "to_dataset"):
25242542
return candidate.to_dataset(**kwargs)
25252543

2526-
return candidate.encode(dataset, index=index, **kwargs)
2544+
return candidate.encode(data, index=index, **kwargs)
25272545

25282546
if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
25292547
return self.last.to_dataset()

merlin/models/tf/utils/batch_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
block_load_func: tp.Optional[tp.Callable[[str], Block]] = None,
7575
schema: tp.Optional[Schema] = None,
7676
output_concat_func=None,
77+
loader_transforms=None,
7778
):
7879
save_path = save_path or tempfile.mkdtemp()
7980
model.save(save_path)
@@ -95,7 +96,9 @@ def __init__(
9596
super().__init__(
9697
save_path,
9798
output_names,
98-
data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size),
99+
data_iterator_func=data_iterator_func(
100+
self.schema, batch_size=batch_size, loader_transforms=loader_transforms
101+
),
99102
model_load_func=model_load_func,
100103
model_encode_func=model_encode,
101104
output_concat_func=output_concat_func,
@@ -172,14 +175,15 @@ def encode_output(output: tf.Tensor):
172175
return output.numpy()
173176

174177

175-
def data_iterator_func(schema, batch_size: int = 512):
178+
def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None):
176179
import merlin.io.dataset
177180

178181
def data_iterator(dataset):
179182
return Loader(
180183
merlin.io.dataset.Dataset(dataset, schema=schema),
181184
batch_size=batch_size,
182185
shuffle=False,
186+
transforms=loader_transforms,
183187
)
184188

185189
return data_iterator
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#
2+
# Copyright (c) 2021, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
import pytest
17+
18+
pytest.importorskip("torch")
19+
pytest.importorskip("pytorch_lightning")

tests/unit/tf/models/test_retrieval.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from pathlib import Path
22

3-
import nvtabular as nvt
3+
import numpy as np
44
import pytest
55
import tensorflow as tf
66

77
import merlin.models.tf as mm
88
from merlin.core.dispatch import make_df
9+
from merlin.dataloader.ops.embeddings import EmbeddingOperator
910
from merlin.io import Dataset
1011
from merlin.models.tf.metrics.topk import (
1112
AvgPrecisionAt,
@@ -24,6 +25,8 @@
2425

2526

2627
def test_two_tower_shared_embeddings():
28+
nvt = pytest.importorskip("nvtabular")
29+
2730
train = make_df(
2831
{
2932
"user_id": [1, 3, 3, 4, 3, 1, 2, 4, 6, 7, 8, 9] * 100,
@@ -435,6 +438,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly):
435438
assert all([metric >= 0 for metric in metrics.values()])
436439

437440

441+
@pytest.mark.parametrize("run_eagerly", [True, False])
442+
def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly):
443+
music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM])
444+
445+
cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1
446+
pretrained_embedding = np.random.rand(cardinality, 12)
447+
448+
loader_transforms = [
449+
EmbeddingOperator(
450+
pretrained_embedding,
451+
lookup_key="item_category",
452+
embedding_name="pretrained_category_embeddings",
453+
),
454+
]
455+
loader = mm.Loader(
456+
music_streaming_data,
457+
schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]),
458+
batch_size=10,
459+
transforms=loader_transforms,
460+
)
461+
schema = loader.output_schema
462+
463+
pretrained_embeddings = mm.PretrainedEmbeddings(
464+
schema.select_by_tag(Tags.EMBEDDING),
465+
output_dims=16,
466+
)
467+
468+
schema = loader.output_schema
469+
470+
query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER))
471+
query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True))
472+
candidate_input = mm.InputBlockV2(
473+
schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings
474+
)
475+
candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True))
476+
model = mm.TwoTowerModelV2(
477+
query,
478+
candidate,
479+
negative_samplers=["in-batch"],
480+
)
481+
model.compile(optimizer="adam", run_eagerly=run_eagerly)
482+
_ = testing_utils.model_test(model, loader)
483+
484+
# Top-K evaluation
485+
candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID)
486+
loader_candidates = mm.Loader(
487+
candidate_features_data,
488+
batch_size=16,
489+
transforms=loader_transforms,
490+
)
491+
492+
topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16)
493+
topk_model.compile(run_eagerly=run_eagerly)
494+
495+
loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id"))
496+
497+
metrics = topk_model.evaluate(loader, return_dict=True)
498+
assert all([metric >= 0 for metric in metrics.values()])
499+
500+
438501
@pytest.mark.parametrize("run_eagerly", [True, False])
439502
@pytest.mark.parametrize("logits_pop_logq_correction", [True, False])
440503
@pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])

0 commit comments

Comments (0)