-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathtest_serving_ranking_models_with_merlin_systems.py
More file actions
116 lines (103 loc) · 4.19 KB
/
test_serving_ranking_models_with_merlin_systems.py
File metadata and controls
116 lines (103 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os

import pytest
from testbook import testbook

from tests.conftest import REPO_ROOT

# Skip this whole module unless the GPU/recsys stack the example needs is
# installed: cuDF (GPU dataframes), TensorFlow, and Merlin Models.
pytest.importorskip("cudf")
pytest.importorskip("tensorflow")
pytest.importorskip("merlin.models")
@pytest.mark.notebook
@testbook(REPO_ROOT / "examples/Serving-Ranking-Models-With-Merlin-Systems.ipynb", execute=False)
def test_example_04_exporting_ranking_models(tb):
    """End-to-end check of the ranking-model serving example notebook.

    The test generates synthetic ALI-CCP data, fits an NVTabular workflow and
    a DLRM model directly here (so the notebook does not have to retrain),
    saves the artifacts where the notebook expects them, executes the
    notebook's export/serving cells, and finally sends a 3-row request to the
    resulting Triton ensemble, asserting one prediction comes back per row.
    """
    import numpy as np
    import tensorflow as tf

    import merlin.models.tf as mm
    import nvtabular as nvt
    from merlin.datasets.synthetic import generate_data
    from merlin.io.dataset import Dataset
    from merlin.schema import Schema, Tags

    # Every artifact (raw/processed data, workflow, model, ensemble) lives
    # under this folder; the notebook is pointed at it via INPUT_FOLDER below.
    DATA_FOLDER = "/tmp/data/"
    NUM_ROWS = 1000000
    BATCH_SIZE = 512

    # Synthesize raw interaction data and persist a 70/30 train/valid split.
    train, valid = generate_data("aliccp-raw", int(NUM_ROWS), set_sizes=(0.7, 0.3))
    train.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "train"))
    valid.to_ddf().to_parquet(os.path.join(DATA_FOLDER, "valid"))

    train_path = os.path.join(DATA_FOLDER, "train", "*.parquet")
    valid_path = os.path.join(DATA_FOLDER, "valid", "*.parquet")
    output_path = os.path.join(DATA_FOLDER, "processed")

    # NVTabular graph: categorify ids/features and tag every column so the
    # model can discover its inputs and the target through the dataset schema.
    user_id = ["user_id"] >> nvt.ops.Categorify() >> nvt.ops.TagAsUserID()
    item_id = ["item_id"] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID()
    targets = ["click"] >> nvt.ops.AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])
    item_features = (
        ["item_category", "item_shop", "item_brand"]
        >> nvt.ops.Categorify()
        >> nvt.ops.TagAsItemFeatures()
    )
    user_features = (
        [
            "user_shops",
            "user_profile",
            "user_group",
            "user_gender",
            "user_age",
            "user_consumption_2",
            "user_is_occupied",
            "user_geography",
            "user_intentions",
            "user_brands",
            "user_categories",
        ]
        >> nvt.ops.Categorify()
        >> nvt.ops.TagAsUserFeatures()
    )
    # Renamed from `outputs` to avoid shadowing by the notebook ref below.
    workflow_graph = user_id + item_id + item_features + user_features + targets
    workflow = nvt.Workflow(workflow_graph)

    train_dataset = nvt.Dataset(train_path)
    valid_dataset = nvt.Dataset(valid_path)

    # Fit statistics on train only, transform both splits, and save the
    # workflow so the notebook can export it as part of the serving ensemble.
    workflow.fit(train_dataset)
    workflow.transform(train_dataset).to_parquet(output_path=output_path + "/train/")
    workflow.transform(valid_dataset).to_parquet(output_path=output_path + "/valid/")
    workflow.save(os.path.join(DATA_FOLDER, "workflow"))

    # Reload the processed splits; their schema drives model construction.
    train = Dataset(os.path.join(output_path, "train", "*.parquet"))
    valid = Dataset(os.path.join(output_path, "valid", "*.parquet"))
    schema = train.schema
    target_column = schema.select_by_tag(Tags.TARGET).column_names[0]

    model = mm.DLRMModel(
        schema,
        embedding_dim=64,
        bottom_block=mm.MLPBlock([128, 64]),
        top_block=mm.MLPBlock([128, 64, 32]),
        prediction_tasks=mm.BinaryClassificationTask(target_column),
    )
    model.compile("adam", run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
    model.fit(train, validation_data=valid, batch_size=BATCH_SIZE)
    model.save(os.path.join(DATA_FOLDER, "dlrm"))

    # Point the notebook at the artifacts produced above, then execute it.
    tb.inject(
        """
        import os
        os.environ["INPUT_FOLDER"] = "/tmp/data/"
        """
    )
    NUM_OF_CELLS = len(tb.cells)
    # Run the notebook in two pieces, skipping a band of cells near the end —
    # presumably the ones that would retrain / launch Triton interactively.
    # NOTE(review): the -12/-9/-6 offsets are tied to the notebook's current
    # cell count; confirm them if the notebook is edited.
    tb.execute_cell(list(range(0, NUM_OF_CELLS - 12)))
    tb.execute_cell(list(range(NUM_OF_CELLS - 9, NUM_OF_CELLS - 6)))

    from merlin.core.dispatch import get_lib

    df_lib = get_lib()

    # Build a 3-row request from the *raw* validation parquet, restricted to
    # the workflow's input columns; drop the target since it is not an input.
    batch = df_lib.read_parquet(
        os.path.join(DATA_FOLDER, "valid", "part.0.parquet"),
        columns=workflow.input_schema.column_names,
    ).head(3)
    batch = batch.drop(columns="click")

    # Output column names as computed inside the notebook.
    output_cols = tb.ref("output_cols")

    from merlin.dataloader.tf_utils import configure_tensorflow

    configure_tensorflow()

    from merlin.systems.triton.utils import run_ensemble_on_tritonserver

    # HACK: the schema contains int64s while the actual data contains int32s
    # (reason unknown); force the request schema to int32 so Triton accepts
    # the batch.
    schema = Schema(
        [col_schema.with_dtype(np.int32) for col_schema in schema.column_schemas.values()]
    )
    response = run_ensemble_on_tritonserver(
        os.path.join(DATA_FOLDER, "ensemble/"),
        schema.without(["click"]),
        batch,
        output_cols,
        "executor_model",
    )

    # One "click" prediction per request row.
    assert len(response["click/binary_classification_task"]) == 3