
Commit 53e3533

Consolidate Embedding Operator and add offsets to embeddings (#138)
* consolidate embedding operator
* add offsets to embeddings
* fix schema shape calculation and clean up tests
1 parent ef229a4 commit 53e3533
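
For orientation, a minimal sketch of the consolidated API this change produces, using only the calls visible in the diffs below; the table shape, the random data, and the "embeddings.npy" path are illustrative, not from this commit:

    import numpy as np

    from merlin.dataloader.ops.embeddings import EmbeddingOperator

    # In-memory table (previously NumpyEmbeddingOperator): one row of
    # embedding values per id, looked up via the "id" column by default.
    embeddings = np.random.rand(1000, 16).astype(np.float32)
    in_memory_op = EmbeddingOperator(embeddings, lookup_key="id", embedding_name="embeddings")

    # Memory-mapped table (previously MmapNumpyEmbedding): pass a .npy path
    # with mmap=True so tables larger than host memory stay on disk.
    mmap_op = EmbeddingOperator("embeddings.npy", mmap=True)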

File tree

2 files changed: +37 -209 lines changed

merlin/dataloader/ops/embeddings.py
tests/unit/dataloader/test_embeddings.py

merlin/dataloader/ops/embeddings.py

Lines changed: 11 additions & 130 deletions
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+from typing import Optional, Union

 import numpy as np

@@ -43,92 +43,15 @@ class EmbeddingOperator(BaseOperator):

     def __init__(
         self,
-        embeddings: np.ndarray,
+        embeddings: Union[np.ndarray, str],
         lookup_key: str = "id",
         embedding_name: str = "embeddings",
-        id_lookup_table=None,
-    ):
-        self.embeddings = embeddings
-        self.lookup_key = lookup_key
-        self.embedding_name = embedding_name
-        self.id_lookup_table = id_lookup_table
-
-    def transform(
-        self, col_selector: ColumnSelector, transformable: Transformable
-    ) -> Transformable:
-        keys = transformable[self.lookup_key]
-        indices = keys.cpu().values
-        if self.id_lookup_table is not None:
-            indices = np.nonzero(np.in1d(self.id_lookup_table, indices))
-        embeddings = self.embeddings[indices]
-        embeddings_col = TensorColumn(embeddings)
-        transformable[self.embedding_name] = (
-            embeddings_col.gpu() if keys.device == Device.GPU else embeddings_col
-        )
-        return transformable
-
-    def compute_output_schema(
-        self,
-        input_schema: Schema,
-        col_selector: ColumnSelector,
-        prev_output_schema: Schema = None,
-    ) -> Schema:
-        """Creates the output schema for this operator.
-
-        Parameters
-        ----------
-        input_schema : Schema
-            schema coming from ancestor nodes
-        col_selector : ColumnSelector
-            subselection of columns to apply to this operator
-        prev_output_schema : Schema, optional
-            the output schema of the previously executed operators, by default None
-
-        Returns
-        -------
-        Schema
-            Schema representing the correct output for this operator.
-        """
-        col_schemas = []
-        for _, col_schema in input_schema.column_schemas.items():
-            col_schemas.append(col_schema)
-        col_schemas.append(
-            ColumnSchema(
-                name=self.embedding_name,
-                tags=[Tags.CONTINUOUS, Tags.EMBEDDING],
-                dtype=self.embeddings.dtype,
-                is_list=True,
-                is_ragged=False,
-            )
-        )
-
-        return Schema(col_schemas)
-
-
-class NumpyEmbeddingOperator(BaseOperator):
-    """Create an embedding table from supplied embeddings to add embedding entry
-    to records based on supplied indices. Support for indices lookup table is available.
-    Embedding table is stored in host memory.
-
-    Parameters
-    ----------
-    embeddings : np.ndarray
-        numpy ndarray representing embedding values
-    lookup_key : str, optional
-        the name of the column that will be used as indices, by default "id"
-    embedding_name : str, optional
-        name of new column of embeddings, added to output, by default "embeddings"
-    id_lookup_table : np.array, optional
-        numpy array of values that represent embedding indices, by default None
-    """
-
-    def __init__(
-        self,
-        embeddings: np.ndarray,
-        lookup_key: str = "id",
-        embedding_name: str = "embeddings",
-        id_lookup_table=None,
+        id_lookup_table: Optional[Union[np.ndarray, str]] = None,
+        mmap=False,
     ):
+        if mmap:
+            embeddings = np.load(embeddings, mmap_mode="r")
+            id_lookup_table = np.load(id_lookup_table) if id_lookup_table else None
         self.embeddings = embeddings
         self.lookup_key = lookup_key
         self.embedding_name = embedding_name
@@ -142,16 +65,12 @@ def transform(
         if self.id_lookup_table is not None:
             indices = np.in1d(self.id_lookup_table, indices)
         embeddings = self.embeddings[indices]
-        # numpy_to_tensor
-        embeddings_col = TensorColumn(embeddings)
+        embeddings_col = TensorColumn(embeddings, offsets=keys.cpu().offsets)
         transformable[self.embedding_name] = (
             embeddings_col.gpu() if keys.device == Device.GPU else embeddings_col
         )
         return transformable

-    def _format_embeddings(self, embeddings, keys):
-        raise NotImplementedError("No logic to format embeddings.")
-
     def compute_output_schema(
         self,
         input_schema: Schema,
@@ -177,53 +96,15 @@ def compute_output_schema(
         col_schemas = []
         for _, col_schema in input_schema.column_schemas.items():
             col_schemas.append(col_schema)
+        id_schema = input_schema.column_schemas[self.lookup_key]
         embedding_dim = self.embeddings.shape[1]
         col_schemas.append(
             ColumnSchema(
                 name=self.embedding_name,
-                tags=[Tags.CONTINUOUS, Tags.EMBEDDING],
+                tags=[Tags.EMBEDDING],
                 dtype=self.embeddings.dtype,
-                is_list=True,
-                is_ragged=False,
-                properties={"value_count": {"min": embedding_dim, "max": embedding_dim}},
+                dims=id_schema.shape.as_tuple + (embedding_dim,),
             )
         )

         return Schema(col_schemas)
-
-
-class MmapNumpyEmbedding(NumpyEmbeddingOperator):
-    """Operator loads numpy embedding table from file using memory map to be used to create
-    torch embedding representations. This allows for larger than host memory embedding
-    tables to be used for embedding lookups. The only limit to the size is what fits in
-    storage, preferred storage device is SSD for faster lookups.
-
-    Parameters
-    ----------
-    embedding_npz : numpy ndarray file
-        file holding numpy ndarray representing embedding table
-    ids_lookup_npz : numpy array file, optional
-        file holding numpy array of values that represent embedding indices, by default None
-    lookup_key : str, optional
-        the name of the column that will be used as indices, by default "id"
-    embedding_name : str, optional
-        name of new column of embeddings, added to output, by default "embeddings"
-    transform_function : _type_, optional
-        function that will transform embedding from numpy to torch, by default None
-    """
-
-    def __init__(
-        self,
-        embedding_npz,
-        ids_lookup_npz=None,
-        lookup_key="id",
-        embedding_name="embeddings",
-    ):
-        embeddings = np.load(embedding_npz, mmap_mode="r")
-        id_lookup = np.load(ids_lookup_npz) if ids_lookup_npz else None
-        super().__init__(
-            embeddings,
-            lookup_key=lookup_key,
-            embedding_name=embedding_name,
-            id_lookup_table=id_lookup,
-        )
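
As a plain-numpy sketch of the lookup path kept in transform above: when an id_lookup_table is given, np.in1d builds a boolean mask over the table's ids and the matching embedding rows are gathered in table order. The values here are made up:

    import numpy as np

    # Hypothetical 5-row table whose ids are stored in reverse order,
    # mirroring the rev_embedding_ids fixture used by the tests.
    embeddings = np.arange(15, dtype=np.float32).reshape(5, 3)
    id_lookup_table = np.array([50, 40, 30, 20, 10])

    batch_ids = np.array([40, 20])
    mask = np.in1d(id_lookup_table, batch_ids)  # [False, True, False, True, False]
    print(embeddings[mask])                     # rows 1 and 3 of the table

The reworked compute_output_schema follows the same consolidation: instead of a fixed value_count property, the embedding column's dims are now derived from the lookup column's shape plus embedding_dim, which is what "fix schema shape calculation" in the commit message refers to.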

tests/unit/dataloader/test_embeddings.py

Lines changed: 26 additions & 79 deletions
@@ -20,13 +20,10 @@

 from merlin.core.dispatch import HAS_GPU
 from merlin.dataloader.loader_base import LoaderBase as Loader  # noqa
-from merlin.dataloader.ops.embeddings import (  # noqa
-    EmbeddingOperator,
-    MmapNumpyEmbedding,
-    NumpyEmbeddingOperator,
-)
+from merlin.dataloader.ops.embeddings import EmbeddingOperator
 from merlin.io import Dataset
 from merlin.schema import Tags
+from merlin.table import TensorColumn, TensorTable


 @pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
@@ -40,17 +37,13 @@ def test_embedding_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_fro
     dataset = Dataset(str(pq_path))
     dataset = dataset.repartition(10)
     schema = dataset.schema
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-
     for col_name in cat_names:
         schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
     dataset.schema = schema
     data_loader = Loader(
         dataset,
         batch_size=batch_size,
-        transforms=[MmapNumpyEmbedding(embeddings_file)],
+        transforms=[EmbeddingOperator(embeddings_file, mmap=True)],
         shuffle=False,
         device=cpu,
     )
@@ -90,13 +83,10 @@ def test_embedding_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embeddin
         schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
     dataset.schema = schema

-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
     data_loader = Loader(
         dataset,
         batch_size=batch_size,
-        transforms=[MmapNumpyEmbedding(embeddings_file, ids_lookup_npz=id_lookup_file)],
+        transforms=[EmbeddingOperator(embeddings_file, id_lookup_table=id_lookup_file, mmap=True)],
         shuffle=False,
         device=cpu,
     )
@@ -121,10 +111,6 @@ def test_embedding_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr
     dataset = Dataset(str(pq_path))
     dataset = dataset.repartition(10)
     schema = dataset.schema
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-
     for col_name in cat_names:
         schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
     dataset.schema = schema
@@ -134,7 +120,7 @@ def test_embedding_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr
     data_loader = Loader(
         dataset,
         batch_size=batch_size,
-        transforms=[NumpyEmbeddingOperator(embeddings_np)],
+        transforms=[EmbeddingOperator(embeddings_np)],
         shuffle=False,
         device=cpu,
     )
@@ -160,10 +146,6 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_
     dataset = Dataset(str(pq_path))
     dataset = dataset.repartition(10)
     schema = dataset.schema
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-
     for col_name in cat_names:
         schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
     dataset.schema = schema
@@ -173,9 +155,7 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_
     data_loader = Loader(
         dataset,
         batch_size=batch_size,
-        transforms=[
-            NumpyEmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())
-        ],
+        transforms=[EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())],
         shuffle=False,
         device=cpu,
     )
@@ -192,77 +172,44 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_


 @pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
-def test_embedding_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_dataframe, cpu):
+def test_embedding_np_dl_with_lookup_ragged(
+    tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu
+):
     cat_names = ["id"]
-    batch_size = 10000
+    batch_size = 5
     pq_path = tmpdir / "id.parquet"
-    embedding_ids.to_parquet(pq_path)
-    dataset = Dataset(str(pq_path))
+    embedding_ids = rev_embedding_ids["id"][:100].to_numpy()
+    offsets = np.array([0, 10, 15, 20, 30, 40, 45, 55, 65, 75, 80, 90, 100])
+    tensor_df = TensorTable({"id": TensorColumn(embedding_ids, offsets=offsets)}).to_df()
+    tensor_df.to_parquet(pq_path)
+    dataset = Dataset(str(pq_path), cpu=bool(cpu))
     dataset = dataset.repartition(10)
     schema = dataset.schema
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-
     for col_name in cat_names:
         schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
     dataset.schema = schema
     paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
     embeddings_ds = Dataset(paths)
-    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
+    embeddings_np = embeddings_ds.to_ddf().compute().to_numpy()[:100, 1:]
     data_loader = Loader(
         dataset,
         batch_size=batch_size,
-        transforms=[EmbeddingOperator(np_tensor)],
+        transforms=[EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids)],
         shuffle=False,
         device=cpu,
     )
     full_len = 0
+    old_end = 0
     for idx, batch in enumerate(data_loader):
         assert "embeddings" in batch[0]
         assert "id" in batch[0]
-        start = idx * batch_size
-        end = start + int(batch[0]["id"].shape[0])
+        start = old_end
+        end = start + int(batch[0]["id"].cpu().values.shape[0])
+        old_end = end
+        id_offsets = batch[0]["id"].cpu().offsets
         embeddings_vals = batch[0]["embeddings"].cpu().values
-        assert (embeddings_vals == np_tensor[start:end]).all()
-        full_len += int(batch[0]["embeddings"].shape[0])
-    assert full_len == embedding_ids.shape[0]
-
-
-@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"])
-def test_embedding_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_dataframe, cpu):
-    cat_names = ["id"]
-    batch_size = 10000
-    pq_path = tmpdir / "id.parquet"
-    embedding_ids = rev_embedding_ids
-    embedding_ids.to_parquet(pq_path)
-    dataset = Dataset(str(pq_path))
-    dataset = dataset.repartition(10)
-    schema = dataset.schema
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-
-    for col_name in cat_names:
-        schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
-    dataset.schema = schema
-    paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
-    embeddings_ds = Dataset(paths)
-    np_tensor = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
-    data_loader = Loader(
-        dataset,
-        batch_size=batch_size,
-        transforms=[EmbeddingOperator(np_tensor, id_lookup_table=embedding_ids.to_numpy())],
-        shuffle=False,
-        device=cpu,
-    )
-    full_len = 0
-    for idx, batch in enumerate(data_loader):
-        assert "embeddings" in batch[0]
-        assert "id" in batch[0]
-        start = idx * batch_size
-        end = start + int(batch[0]["id"].shape[0])
-        embeddings_vals = batch[0]["embeddings"].cpu().values
-        assert (embeddings_vals == np_tensor[start:end]).all()
+        embeddings_offs = batch[0]["embeddings"].cpu().offsets
+        assert (embeddings_vals == embeddings_np[start:end]).all()
+        assert (embeddings_offs == id_offsets).all()
         full_len += int(batch[0]["embeddings"].shape[0])
-    assert full_len == embedding_ids.shape[0]
+    assert full_len == offsets.shape[0] - 1
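
The offsets convention the new ragged test asserts, sketched with plain numpy (arbitrary values): offsets delimit each row's slice of the flat values array, so a column with N + 1 offsets holds N rows. That is why the test checks full_len == offsets.shape[0] - 1 and that the embeddings column carries the same offsets as the id column:

    import numpy as np

    values = np.arange(10)             # flat ids across all rows
    offsets = np.array([0, 3, 5, 10])  # row i is values[offsets[i]:offsets[i + 1]]

    num_rows = offsets.shape[0] - 1    # 3 ragged rows
    rows = [values[offsets[i]:offsets[i + 1]] for i in range(num_rows)]
    # rows -> [array([0, 1, 2]), array([3, 4]), array([5, 6, 7, 8, 9])]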
