Skip to content

Commit 3ee0694

Browse files
sararb, rnyak, jperez999
authored
Add pre-trained embeddings support to T4Rec input blocks (#690)
* add support of pretrained embeddings to t4rec input blocks -- v0
* add sequence_combiner
* fix merge conflict
* add pre-trained op to t4rec trainer
* remove custom padding from T4rec dataloader
* revert back custom padding
* Add tabular-normalization
* fix linting
* replace value_count with is_list to check if a feature is a list column
* extend support of 2-D pretrained embeddings
* update docstrings
* fix support of 2d/3d pretrained embeddings
* try to fix failing test in gpu-ci
* adding ops-bot yaml file (#708)
* try to fix torchscript error raised by gpu-ci

---------

Co-authored-by: rnyak <[email protected]>
Co-authored-by: Julio Perez <[email protected]>
1 parent 794365c commit 3ee0694

File tree

9 files changed

+510
-43
lines changed

9 files changed

+510
-43
lines changed

tests/unit/torch/features/test_sequential.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414
# limitations under the License.
1515
#
1616

17+
import numpy as np
1718
import pytest
19+
from merlin.dataloader.ops.embeddings import EmbeddingOperator
20+
from merlin.io import Dataset
21+
from merlin.schema import ColumnSchema
1822
from merlin.schema import Schema as CoreSchema
1923
from merlin.schema import Tags
2024

2125
import transformers4rec.torch as tr
2226
from tests.conftest import parametrize_schemas
27+
from transformers4rec.torch.utils.data_utils import MerlinDataLoader
2328

2429

2530
@parametrize_schemas("yoochoose")
@@ -217,3 +222,173 @@ def test_sequential_and_non_sequential_tabular_features(schema, torch_yoochoose_
217222
outputs = tab_module(torch_yoochoose_like)
218223

219224
assert list(outputs.shape) == [100, 20, 203]
225+
226+
227+
@pytest.mark.parametrize(
    "pretrained_dim",
    [None, 128, {"pretrained_item_id_embeddings": 128, "pretrained_user_id_embeddings": 128}],
)
def test_sequential_input_block_with_pretrained_embeddings(pretrained_dim):
    """Test TabularSequenceFeatures with pre-trained embeddings from the dataloader.

    Covers three configurations of `pretrained_output_dims`: no projection
    (None), a single int applied to all pre-trained features, and a per-feature
    dict. Also checks that ``aggregation="concat"`` broadcasts the 2-D
    (non-sequential) pre-trained feature across the sequence dimension.
    """
    data = tr.data.music_streaming_testing_data
    seq_schema = data.merlin_schema.select_by_name(["item_id"])
    # Register "user_id" as a non-sequential feature: dims=(None,) marks a
    # scalar (1-D) column, so its pre-trained embedding comes out 2-D
    # (batch, dim) rather than 3-D.
    user_cardinality = data.merlin_schema["user_id"].int_domain.max + 1
    seq_schema = seq_schema + CoreSchema(
        [
            ColumnSchema(
                "user_id",
                dtype=np.int32,
                tags=[Tags.USER, Tags.CATEGORICAL],
                properties={
                    "domain": {"name": "user_id", "min": 0, "max": user_cardinality},
                },
                dims=(None,),
            )
        ]
    )
    batch_size, max_length = 128, 20
    embedding_dim_default, item_dim, user_dim = 8, 32, 16

    # generate pre-trained embeddings tables covering each feature's domain
    item_cardinality = seq_schema["item_id"].int_domain.max + 1
    np_emb_item_id = np.random.rand(item_cardinality, item_dim)
    np_emb_user_id = np.random.rand(user_cardinality, user_dim)
    embeddings_op_item = EmbeddingOperator(
        np_emb_item_id, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
    )
    embeddings_op_user = EmbeddingOperator(
        np_emb_user_id, lookup_key="user_id", embedding_name="pretrained_user_id_embeddings"
    )

    # set dataloader with pre-trained embeddings
    data_loader = MerlinDataLoader.from_schema(
        seq_schema,
        data.path,
        batch_size=batch_size,
        max_sequence_length=max_length,
        transforms=[embeddings_op_item, embeddings_op_user],
        shuffle=False,
    )

    batch, _ = next(iter(data_loader))

    # Sequential input block with pre-trained features
    inputs = tr.TabularSequenceFeatures.from_schema(
        data_loader.output_schema,
        max_sequence_length=max_length,  # use the shared constant, not a literal 20
        pretrained_output_dims=pretrained_dim,
        aggregation=None,
    )

    # Sequential input + concat aggregation, which inherently performs
    # broadcasting of 2-D features.
    inputs_with_concat = tr.TabularSequenceFeatures.from_schema(
        data_loader.output_schema,
        embedding_dim_default=embedding_dim_default,
        max_sequence_length=max_length,
        aggregation="concat",
    )

    output = inputs.to(batch["item_id"].device).double()(batch)
    concat_output = inputs_with_concat.to(batch["item_id"].device).double()(batch)

    # concat width = two trainable embeddings + both raw pre-trained tables
    assert concat_output.shape[-1] == embedding_dim_default * 2 + item_dim + user_dim

    assert "pretrained_item_id_embeddings" in output

    # Derive the expected output dims from the parametrized value instead of
    # duplicating the literal 128 in every assertion branch.
    if pretrained_dim is None:
        expected_item_dim, expected_user_dim = item_dim, user_dim
    elif isinstance(pretrained_dim, dict):
        expected_item_dim = pretrained_dim["pretrained_item_id_embeddings"]
        expected_user_dim = pretrained_dim["pretrained_user_id_embeddings"]
    else:
        expected_item_dim = expected_user_dim = pretrained_dim

    assert list(output["pretrained_item_id_embeddings"].shape) == [
        batch_size,
        max_length,
        expected_item_dim,
    ]
    assert list(output["pretrained_user_id_embeddings"].shape) == [
        batch_size,
        expected_user_dim,
    ]
317+
318+
319+
@pytest.mark.parametrize(
    "pretrained_dim",
    [None, 128, {"pretrained_item_id_embeddings": 128, "pretrained_user_id_embeddings": 128}],
)
def test_non_sequential_input_block_with_pretrained_embeddings(pretrained_dim):
    """Test TabularFeatures (non-sequential block) with pre-trained embeddings.

    The 3-D item embedding is reduced over the sequence axis via
    ``sequence_combiner="mean"``, so every output here is expected to be 2-D
    (batch, dim). Covers the same three `pretrained_output_dims` configurations
    as the sequential test: None, a single int, and a per-feature dict.
    """
    data = tr.data.music_streaming_testing_data
    seq_schema = data.merlin_schema.select_by_name(["item_id"])
    # Register "user_id" as a non-sequential feature: dims=(None,) marks a
    # scalar (1-D) column, so its pre-trained embedding is 2-D (batch, dim).
    user_cardinality = data.merlin_schema["user_id"].int_domain.max + 1
    seq_schema = seq_schema + CoreSchema(
        [
            ColumnSchema(
                "user_id",
                dtype=np.int32,
                tags=[Tags.USER, Tags.CATEGORICAL],
                properties={
                    "domain": {"name": "user_id", "min": 0, "max": user_cardinality},
                },
                dims=(None,),
            )
        ]
    )
    batch_size, max_length = 128, 20
    item_dim, user_dim = 32, 16

    # generate pre-trained embeddings tables covering each feature's domain
    item_cardinality = seq_schema["item_id"].int_domain.max + 1
    np_emb_item_id = np.random.rand(item_cardinality, item_dim)
    np_emb_user_id = np.random.rand(user_cardinality, user_dim)
    embeddings_op_item = EmbeddingOperator(
        np_emb_item_id, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
    )
    embeddings_op_user = EmbeddingOperator(
        np_emb_user_id, lookup_key="user_id", embedding_name="pretrained_user_id_embeddings"
    )

    # set dataloader with pre-trained embeddings
    data_loader = MerlinDataLoader.from_schema(
        seq_schema,
        Dataset(data.path, schema=seq_schema),
        batch_size=batch_size,
        max_sequence_length=max_length,
        transforms=[embeddings_op_item, embeddings_op_user],
        shuffle=False,
    )

    batch, _ = next(iter(data_loader))

    # Non-Sequential input block with a 3-D pre-trained feature
    inputs = tr.TabularFeatures.from_schema(
        data_loader.output_schema,
        pretrained_output_dims=pretrained_dim,
        sequence_combiner="mean",
        aggregation=None,
    )
    output = inputs.to(batch["item_id"].device).double()(batch)

    assert "pretrained_item_id_embeddings" in output

    # Derive the expected output dims from the parametrized value instead of
    # duplicating the literal 128 in every assertion branch.
    if pretrained_dim is None:
        expected_item_dim, expected_user_dim = item_dim, user_dim
    elif isinstance(pretrained_dim, dict):
        expected_item_dim = pretrained_dim["pretrained_item_id_embeddings"]
        expected_user_dim = pretrained_dim["pretrained_user_id_embeddings"]
    else:
        expected_item_dim = expected_user_dim = pretrained_dim

    assert list(output["pretrained_item_id_embeddings"].shape) == [
        batch_size,
        expected_item_dim,
    ]
    assert list(output["pretrained_user_id_embeddings"].shape) == [
        batch_size,
        expected_user_dim,
    ]

tests/unit/torch/test_trainer.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,75 @@ def test_trainer_trop_k_with_wrong_task():
598598
recsys_trainer.predict(data.path)
599599

600600
assert "Top-k prediction is specific to NextItemPredictionTask" in str(excinfo.value)
601+
602+
603+
def test_trainer_with_pretrained_embeddings():
    """End-to-end train/evaluate with dataloader-provided pre-trained embeddings."""
    import numpy as np
    from merlin.dataloader.ops.embeddings import EmbeddingOperator
    from merlin.io import Dataset

    from transformers4rec.torch.utils.data_utils import MerlinDataLoader

    data = tr.data.music_streaming_testing_data
    schema = data.merlin_schema.select_by_name(
        ["item_id", "item_category", "item_recency", "item_genres", "user_id"]
    )
    batch_size, max_length, pretrained_dim = 128, 20, 16

    # Random lookup table spanning the full item-id domain.
    item_cardinality = schema["item_id"].int_domain.max + 1
    item_table = np.random.rand(item_cardinality, pretrained_dim)
    embeddings_op = EmbeddingOperator(
        item_table, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
    )

    # Dataloader that joins the pre-trained embeddings onto every batch.
    data_loader = MerlinDataLoader.from_schema(
        schema,
        Dataset(data.path, schema=schema),
        max_sequence_length=max_length,
        batch_size=batch_size,
        transforms=[embeddings_op],
        shuffle=False,
    )

    # The model schema comes from the dataloader so it includes the new column.
    inputs = tr.TabularSequenceFeatures.from_schema(
        data_loader.output_schema,
        max_sequence_length=max_length,
        pretrained_output_dims=8,
        normalizer="layer-norm",
        d_output=64,
        masking="mlm",
    )
    prediction_task = tr.NextItemPredictionTask(weight_tying=True)
    model = tconf.XLNetConfig.build(64, 4, 2, 20).to_torch_model(
        inputs, prediction_task, max_sequence_length=max_length
    )

    assert isinstance(model.input_schema, Schema)

    training_args = trainer.T4RecTrainingArguments(
        output_dir=".",
        max_steps=5,
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size // 2,
        max_sequence_length=max_length,
        fp16=False,
        report_to=[],
        debug=["r"],
    )
    # Pass the merlin dataloader with pre-trained embeddings explicitly.
    recsys_trainer = tr.Trainer(
        model=model,
        args=training_args,
        schema=schema,
        train_dataloader=data_loader,
        eval_dataloader=data_loader,
        compute_metrics=True,
    )

    recsys_trainer.train()
    eval_metrics = recsys_trainer.evaluate(eval_dataset=data.path, metric_key_prefix="eval")

    assert isinstance(eval_metrics, dict)
    assert eval_metrics["eval_/loss"] is not None

transformers4rec/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .features.embedding import (
4141
EmbeddingFeatures,
4242
FeatureConfig,
43+
PretrainedEmbeddingFeatures,
4344
PretrainedEmbeddingsInitializer,
4445
SoftEmbedding,
4546
SoftEmbeddingFeatures,
@@ -104,6 +105,7 @@
104105
"EmbeddingFeatures",
105106
"SoftEmbeddingFeatures",
106107
"PretrainedEmbeddingsInitializer",
108+
"PretrainedEmbeddingFeatures",
107109
"TabularSequenceFeatures",
108110
"SequenceEmbeddingFeatures",
109111
"FeatureConfig",

0 commit comments

Comments
 (0)