Skip to content

Commit f7bc901

Browse files
authored
[Distributed] Integrate HybridBackend in collective training mode. (#912)
Signed-off-by: JunqiHu <[email protected]>
1 parent 14c25d8 commit f7bc901

File tree

8 files changed

+1959
-5
lines changed

8 files changed

+1959
-5
lines changed

tensorflow/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ py_library(
170170
"//tensorflow/python/compiler",
171171
"//tensorflow/python/data",
172172
"//tensorflow/python/distribute",
173+
"//tensorflow/python/distribute:deeprec_collective",
173174
"//tensorflow/python/distribute:combinations",
174175
"//tensorflow/python/distribute:distribute_config",
175176
"//tensorflow/python/distribute:estimator_training",

tensorflow/python/distribute/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@ package(
1010

1111
exports_files(["LICENSE"])
1212

13+
py_library(
14+
name = "deeprec_collective",
15+
srcs = [
16+
"group_embedding_collective_strategy.py",
17+
"launch.py",
18+
"hvd_strategy.py",
19+
],
20+
srcs_version = "PY2AND3",
21+
deps = [
22+
"//tensorflow/python:array_ops",
23+
"//tensorflow/python:framework_ops",
24+
"//tensorflow/python:math_ops",
25+
"//tensorflow/python:nccl_ops",
26+
],
27+
)
28+
1329
py_library(
1430
name = "distribute_test_lib_pip",
1531
deps = [

tensorflow/python/distribute/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow.python.distribute import distribution_strategy_context
2626
from tensorflow.python.distribute import mirrored_strategy
2727
from tensorflow.python.distribute import one_device_strategy
28+
from tensorflow.python.distribute import launch
2829
from tensorflow.python.distribute.experimental import collective_all_reduce_strategy
2930
from tensorflow.python.distribute.experimental import parameter_server_strategy
3031
# pylint: enable=unused-import
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2023 Alibaba Group Holding Limited. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# =============================================================================
17+
18+
19+
from tensorflow.python.framework.group_embedding_types import (
20+
DistStrategy,
21+
set_group_lookup_strategy,
22+
)
23+
24+
import os
25+
import contextlib
26+
from tensorflow_estimator.python.estimator import estimator as _estimator_lib
27+
28+
29+
class CollectiveStrategy:
    r"""
    A thin interface to all kinds of synchronized training strategy.

    The backend is selected by the ``COLLECTIVE_STRATEGY`` environment
    variable (default ``"sok"``) and is compared against
    ``DistStrategy.SOK.value`` and ``DistStrategy.HB.value``:

    * sok: imports and initializes ``horovod.tensorflow`` and
      ``sparse_operation_kit.experiment``; the methods below then delegate
      to ``tensorflow.python.distribute.hvd_strategy``.
    * hb: imports ``hybridbackend.tensorflow``; the methods below delegate
      to that module.

    Raises:
      ImportError: if the selected backend cannot be imported/initialized.
      ValueError: if ``COLLECTIVE_STRATEGY`` names an unknown strategy.
    """

    def __init__(self):
        # Exactly one of `_hvd` / `_hb` is set after a successful
        # construction; every other method dispatches on which one is set.
        self._hvd = None
        self._hb = None
        self._sok = None
        strategy = os.getenv("COLLECTIVE_STRATEGY", "sok")
        if strategy == DistStrategy.SOK.value:
            try:
                import horovod.tensorflow as hvd
                hvd.init()
                from sparse_operation_kit import experiment as sok
                sok.init()
            except Exception:
                # `Exception` (not a bare except) so KeyboardInterrupt and
                # SystemExit still propagate. ImportError is kept as the
                # raised type for callers; the original failure stays
                # attached as the implicit exception context.
                raise ImportError(
                    "While param `strategy` in enable_distributed_strategy"
                    " is given `sok`, sok module initialize error,"
                    " please double check"
                )

            self._sok = sok
            self._hvd = hvd
            set_group_lookup_strategy(strategy)
        elif strategy == DistStrategy.HB.value:
            try:
                import hybridbackend.tensorflow as hb
            except Exception:
                raise ImportError(
                    "While param `strategy` in enable_distributed_strategy"
                    " is given `hb`, hb module initialize error,"
                    " please double check"
                )
            self._hb = hb
            set_group_lookup_strategy(strategy)
        else:
            # Interpolate explicitly: passing `strategy` as a second
            # argument to ValueError would leave the "%s" unformatted.
            raise ValueError(
                "accepted `COLLECTIVE_STRATEGY` is sok or hb, while given %s"
                % strategy
            )

    @contextlib.contextmanager
    def scope(self, *args, **kwargs):
        """Enter the active backend's training scope and yield its context.

        `*args` and `**kwargs` are accepted but not forwarded to the backend.
        """
        if self._hvd:
            from tensorflow.python.distribute import hvd_strategy
            with hvd_strategy.scope() as context:
                yield context
        elif self._hb:
            with self._hb.scope() as context:
                yield context

    @contextlib.contextmanager
    def embedding_scope(self, **kwargs):
        """Enter the active backend's embedding scope and yield its context.

        `**kwargs` is accepted but not forwarded to the backend.
        """
        if self._hvd:
            from tensorflow.python.distribute import hvd_strategy
            with hvd_strategy.embedding_scope() as context:
                yield context
        elif self._hb:
            with self._hb.embedding_scope() as context:
                yield context

    def world_size(self):
        """Return `hvd.size()` or `hb.context.world_size` for the backend."""
        if self._hvd:
            return self._hvd.size()
        elif self._hb:
            return self._hb.context.world_size

    def rank(self):
        """Return `hvd.rank()` or `hb.context.rank` for the backend."""
        if self._hvd:
            return self._hvd.rank()
        elif self._hb:
            return self._hb.context.rank

    def estimator(self, model_fn, **kwargs):
        """Build an estimator for `model_fn` using the active backend.

        Args:
          model_fn: passed as the first positional argument to the
            backend's estimator class.
          **kwargs: forwarded unchanged to the estimator constructor.

        Returns:
          An instance of the horovod-wrapped `Estimator` (sok) or of
          `hb.estimator.Estimator` (hb).
        """
        if self._hvd:
            from tensorflow.python.distribute.hvd_strategy import wraps_estimator
            _estimator = wraps_estimator(_estimator_lib.Estimator)
        elif self._hb:
            # Must go through `self._hb`: the name `hb` is only a local of
            # `__init__` and is not visible here.
            _estimator = self._hb.estimator.Estimator

        return _estimator(model_fn, **kwargs)

    def export_saved_model(
        self,
        savedmodel_dir,
        checkpoint_dir=None,
        signature_def_fn=None,
        assets_extra=None,
        as_text=False,
        clear_devices=True,
        strip_default_attrs=True,
    ):
        """Export a SavedModel through the active backend.

        All arguments are forwarded positionally, in this order, to
        `hvd_strategy.export` (sok) or `hb.train.export` (hb).

        Returns:
          Whatever the backend's export function returns.
          NOTE(review): presumably the export path -- confirm against the
          backend implementations.
        """
        if self._hvd:
            from tensorflow.python.distribute import hvd_strategy
            return hvd_strategy.export(
                savedmodel_dir,
                checkpoint_dir,
                signature_def_fn,
                assets_extra,
                as_text,
                clear_devices,
                strip_default_attrs,
            )
        elif self._hb:
            return self._hb.train.export(
                savedmodel_dir,
                checkpoint_dir,
                signature_def_fn,
                assets_extra,
                as_text,
                clear_devices,
                strip_default_attrs,
            )

0 commit comments

Comments
 (0)