Skip to content

Commit da7e1a7

Browse files
tgrelnv-kkudrynski
authored andcommitted
[DLRM/TF2] CPU offloading
1 parent 41f582b commit da7e1a7

File tree

8 files changed

+69
-26
lines changed

8 files changed

+69
-26
lines changed

TensorFlow2/Recommendation/DLRM_and_DCNv2/Dockerfile

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
#
1515
# author: Tomasz Grel ([email protected])
1616

17-
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:23.02-tf2-py3
18-
FROM ${FROM_IMAGE_NAME}
17+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:23.06-tf2-py3
18+
FROM nvcr.io/nvidia/tritonserver:23.06-py3-sdk as clientsdk
19+
FROM ${FROM_IMAGE_NAME} as base
1920

20-
ARG DISTRIBUTED_EMBEDDINGS_COMMIT=c635ed84
21+
ARG DISTRIBUTED_EMBEDDINGS_COMMIT=45cffaa8
2122

2223
WORKDIR /dlrm
2324

@@ -38,8 +39,7 @@ RUN rm -rf distributed-embeddings &&\
3839
pip install artifacts/*.whl &&\
3940
cd ..
4041

41-
ADD . .
42-
42+
ADD tensorflow-dot-based-interact tensorflow-dot-based-interact
4343
RUN mkdir -p /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_party/gpus/cuda/ &&\
4444
cd tensorflow-dot-based-interact &&\
4545
make clean &&\
@@ -49,5 +49,13 @@ RUN mkdir -p /usr/local/lib/python3.8/dist-packages/tensorflow/include/third_par
4949
pip install ./artifacts/tensorflow_dot_based_interact-*.whl &&\
5050
cd ..
5151

52+
COPY --from=clientsdk /workspace/install/python/tritonclient-2.35.0-py3-*.whl /dlrm/
53+
RUN if [[ "$(uname -m)" == "x86_64" ]]; \
54+
then echo x86; pip install tritonclient-2.35.0-py3-none-manylinux1_x86_64.whl[all]; \
55+
else echo arm; pip install tritonclient-2.35.0-py3-none-manylinux2014_aarch64.whl[all]; \
56+
fi
57+
58+
ADD . .
59+
5260
ENV HOROVOD_CYCLE_TIME=0.2
5361
ENV HOROVOD_ENABLE_ASYNC_COMPLETION=1

TensorFlow2/Recommendation/DLRM_and_DCNv2/dataloading/dataloader.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def _create_pipelines_tf_raw(**kwargs):
8080
local_categorical_feature_names=local_categorical_names,
8181
rank=kwargs['rank'],
8282
world_size=kwargs['world_size'],
83-
concat_features=kwargs['concat_features'])
83+
concat_features=kwargs['concat_features'],
84+
data_parallel_categoricals=kwargs['data_parallel_input'])
8485

8586
test_dataset = TfRawBinaryDataset(feature_spec=feature_spec,
8687
instance=TEST_MAPPING,
@@ -89,7 +90,8 @@ def _create_pipelines_tf_raw(**kwargs):
8990
local_categorical_feature_names=local_categorical_names,
9091
rank=kwargs['rank'],
9192
world_size=kwargs['world_size'],
92-
concat_features=kwargs['concat_features'])
93+
concat_features=kwargs['concat_features'],
94+
data_parallel_categoricals=kwargs['data_parallel_input'])
9395
return train_dataset, test_dataset
9496

9597

@@ -113,7 +115,8 @@ def _create_pipelines_split_tfrecords(**kwargs):
113115

114116

115117
def create_input_pipelines(dataset_type, dataset_path, train_batch_size, test_batch_size,
116-
table_ids, feature_spec, rank=0, world_size=1, concat_features=False):
118+
table_ids, feature_spec, rank=0, world_size=1, concat_features=False,
119+
data_parallel_input=False):
117120

118121
# pass along all arguments except dataset type
119122
kwargs = locals()

TensorFlow2/Recommendation/DLRM_and_DCNv2/dataloading/raw_binary_dataset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,14 @@ def __init__(
6767
numerical_features_enabled: bool = False,
6868
rank: int = 0,
6969
world_size: int = 1,
70-
concat_features: bool = False
70+
concat_features: bool = False,
71+
data_parallel_categoricals = False,
7172
):
7273

7374
self._concat_features = concat_features
7475
self._feature_spec = feature_spec
7576
self._batch_size = batch_size
77+
self._data_parallel_categoricals = data_parallel_categoricals
7678

7779
local_batch_size = int(batch_size / world_size)
7880
batch_sizes_per_gpu = [local_batch_size] * world_size
@@ -180,6 +182,8 @@ def decode_batch(self, labels, numerical_features, categorical_features, concat_
180182
feature = tf.cast(feature, dtype=tf.int32)
181183
feature = tf.expand_dims(feature, axis=1)
182184
feature = tf.reshape(feature, [self._batch_size, 1])
185+
if self._data_parallel_categoricals:
186+
feature = feature[self.dp_begin_idx:self.dp_end_idx]
183187
cat_data.append(feature)
184188
if self._concat_features:
185189
cat_data = tf.concat(cat_data, axis=1)

TensorFlow2/Recommendation/DLRM_and_DCNv2/deployment/tf/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
# author: Tomasz Grel ([email protected])
1616

1717

18-
FROM nvcr.io/nvidia/tritonserver:23.02-py3 as tritonserver
18+
FROM nvcr.io/nvidia/tritonserver:23.06-py3 as tritonserver
1919

2020
WORKDIR /opt/tritonserver

TensorFlow2/Recommendation/DLRM_and_DCNv2/main.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ def define_common_flags():
5050
flags.DEFINE_string("dist_strategy", default='memory_balanced',
5151
help="Strategy for the Distributed Embeddings to use. Supported options are"
5252
"'memory_balanced', 'basic' and 'memory_optimized'")
53-
flags.DEFINE_integer("column_slice_threshold", default=10*1000*1000*1000,
53+
flags.DEFINE_integer("column_slice_threshold", default=5*1000*1000*1000,
54+
help='Number of elements above which a distributed embedding will be sliced across'
55+
'multiple devices')
56+
flags.DEFINE_integer("row_slice_threshold", default=10*1000*1000*1000,
57+
help='Number of elements above which a distributed embedding will be sliced across'
58+
'multiple devices')
59+
flags.DEFINE_integer("data_parallel_threshold", default=None,
5460
help='Number of elements above which a distributed embedding will be sliced across'
5561
'multiple devices')
5662

@@ -97,6 +103,8 @@ def define_common_flags():
97103
flags.DEFINE_enum("dataset_type", default="tf_raw",
98104
enum_values=['tf_raw', 'synthetic', 'split_tfrecords'],
99105
help='The type of the dataset to use')
106+
flags.DEFINE_boolean("data_parallel_input", default=False, help="Use a data-parallel dataloader,"
107+
" i.e., load a local batch of of data for all input features")
100108

101109
# Synthetic dataset settings
102110
flags.DEFINE_boolean("synthetic_dataset_use_feature_spec", default=False,
@@ -296,14 +304,18 @@ def main():
296304
categorical_cardinalities=dataset_metadata.categorical_cardinalities,
297305
transpose=False)
298306

307+
table_ids = model.sparse_model.get_local_table_ids(hvd.rank())
308+
print(f'local feature ids={table_ids}')
309+
299310
train_pipeline, validation_pipeline = create_input_pipelines(dataset_type=FLAGS.dataset_type,
300311
dataset_path=FLAGS.dataset_path,
301312
train_batch_size=FLAGS.batch_size,
302313
test_batch_size=FLAGS.valid_batch_size,
303-
table_ids=model.sparse_model.get_local_table_ids(hvd.rank()),
314+
table_ids=table_ids,
304315
feature_spec=FLAGS.feature_spec,
305316
rank=hvd.rank(), world_size=hvd.size(),
306-
concat_features=FLAGS.concat_embedding)
317+
concat_features=FLAGS.concat_embedding,
318+
data_parallel_input=FLAGS.data_parallel_input)
307319

308320
mlp_optimizer, embedding_optimizer = create_optimizers(FLAGS)
309321

TensorFlow2/Recommendation/DLRM_and_DCNv2/nn/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _find_first_gpu_index(self):
265265
reversed_sizes = self.table_sizes[idx_mapping]
266266

267267
cumulative_size = np.cumsum(reversed_sizes)
268-
cumulative_indicators = (cumulative_size > self.memory_threshold * 2 ** 30).tolist()
268+
cumulative_indicators = (cumulative_size > self.memory_threshold * (10 ** 9)).tolist()
269269
if True in cumulative_indicators:
270270
index = cumulative_indicators.index(True)
271271
else:

TensorFlow2/Recommendation/DLRM_and_DCNv2/nn/sparse_model.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import json
2121

2222
from distributed_embeddings.python.layers import dist_model_parallel as dmp
23-
from distributed_embeddings.python.layers import embedding
2423

2524
from utils.checkpointing import get_variable_path
2625

@@ -29,7 +28,19 @@
2928

3029
sparse_model_parameters = ['use_mde_embeddings', 'embedding_dim', 'column_slice_threshold',
3130
'embedding_zeros_initializer', 'embedding_trainable', 'categorical_cardinalities',
32-
'concat_embedding', 'cpu_offloading_threshold_gb']
31+
'concat_embedding', 'cpu_offloading_threshold_gb',
32+
'data_parallel_input', 'row_slice_threshold', 'data_parallel_threshold']
33+
34+
def _gigabytes_to_elements(gb, dtype=tf.float32):
35+
if gb is None:
36+
return None
37+
38+
if dtype == tf.float32:
39+
bytes_per_element = 4
40+
else:
41+
raise ValueError(f'Unsupported dtype: {dtype}')
42+
43+
return gb * 10**9 / bytes_per_element
3344

3445
class SparseModel(tf.keras.Model):
3546
def __init__(self, **kwargs):
@@ -61,21 +72,21 @@ def _create_embeddings(self):
6172
for table_size, dim in zip(self.categorical_cardinalities, self.embedding_dim):
6273
if hvd.rank() == 0:
6374
print(f'Creating embedding with size: {table_size} {dim}')
64-
if self.use_mde_embeddings:
65-
e = embedding.Embedding(input_dim=table_size, output_dim=dim,
66-
combiner='sum', embeddings_initializer=initializer_cls())
67-
else:
68-
e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim,
69-
embeddings_initializer=initializer_cls())
75+
e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim,
76+
embeddings_initializer=initializer_cls())
7077
self.embedding_layers.append(e)
7178

79+
gpu_size = _gigabytes_to_elements(self.cpu_offloading_threshold_gb)
7280
self.embedding = dmp.DistributedEmbedding(self.embedding_layers,
7381
strategy='memory_balanced',
74-
dp_input=False,
75-
column_slice_threshold=self.column_slice_threshold)
82+
dp_input=self.data_parallel_input,
83+
column_slice_threshold=self.column_slice_threshold,
84+
row_slice_threshold=self.row_slice_threshold,
85+
data_parallel_threshold=self.data_parallel_threshold,
86+
gpu_embedding_size=gpu_size)
7687

7788
def get_local_table_ids(self, rank):
78-
if self.use_concat_embedding:
89+
if self.use_concat_embedding or self.data_parallel_input:
7990
return list(range(self.num_all_categorical_features))
8091
else:
8192
return self.embedding.strategy.input_ids_list[rank]
@@ -127,4 +138,10 @@ def save_config(self, path):
127138
def from_config(path):
128139
with open(path) as f:
129140
config = json.load(fp=f)
141+
if 'data_parallel_input' not in config:
142+
config['data_parallel_input'] = False
143+
if 'row_slice_threshold' not in config:
144+
config['row_slice_threshold'] = None
145+
if 'data_parallel_threshold' not in config:
146+
config['data_parallel_threshold'] = None
130147
return SparseModel(**config)

TensorFlow2/Recommendation/DLRM_and_DCNv2/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ tqdm
77
pyyaml
88
onnxruntime
99
git+https://github.com/onnx/tensorflow-onnx
10-
tritonclient[all]==2.31
1110
numpy<1.24
1211
tabulate>=0.8.7
1312
natsort>=7.0.0

0 commit comments

Comments
 (0)