
Commit 70e0e3d

[cherry-pick] Mechanism that converts startup_program initializers to BF16 (#32720) (#32764)
* Add casting initializers for bf16 training
* Changes after review
* Correct test and add comment

Co-authored-by: joanna.wozna.intel <[email protected]>
1 parent 5fdd85b commit 70e0e3d

File tree

9 files changed: +131 additions, −30 deletions


python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,7 @@ def __init__(self,
         self.bf16_list = copy.copy(bf16_list)
         self.fp32_list = copy.copy(fp32_list)
         self.gray_list = copy.copy(gray_list)
+        self.bf16_initializer_list = copy.copy(bf16_initializer_list)
         self.unsupported_list = copy.copy(unsupported_list)
         self.fp32_varnames = copy.copy(custom_fp32_varnames)
         self._update_list()
@@ -79,6 +80,8 @@ def _update_list(self):
                 self.unsupported_list.add(op_name)
 
 
+bf16_initializer_list = {'fill_constant', 'uniform_random'}
+
 # always bf16
 bf16_list = {'elementwise_add', }

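The new set is exposed on AutoMixedPrecisionListsBF16 instances, so the conversion pass can ask which initializer op types are eligible. A minimal sketch of inspecting it (the import path is an assumption based on the contrib package layout, not shown in this diff):

    from paddle.fluid.contrib.mixed_precision.bf16 import AutoMixedPrecisionListsBF16  # path assumed

    # With default arguments the attribute mirrors the module-level set above,
    # i.e. only 'fill_constant' and 'uniform_random' initializers may be
    # rewritten to BF16 by cast_initializers_to_bf16.
    amp_lists = AutoMixedPrecisionListsBF16()
    print(amp_lists.bf16_initializer_list)
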
python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py

Lines changed: 50 additions & 1 deletion
@@ -232,7 +232,52 @@ def bf16_guard():
     yield
 
 
-def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
+def are_post_ops_bf16(post_ops, keep_fp32_ops):
+    for post_op in post_ops:
+        for op in post_op:
+            if op.type in keep_fp32_ops:
+                return False
+    return True
+
+
+def cast_initializers_to_bf16(startup_prog,
+                              amp_lists,
+                              block,
+                              all_ops,
+                              keep_fp32_ops,
+                              to_bf16_var_names=None):
+    prepend_ops = startup_prog.global_block().ops
+    for op in prepend_ops:
+        if str(op.type) in amp_lists.bf16_initializer_list:
+            change_op = True
+            op_post_ops = []
+            op_out_vars = []
+            for out_name in op.output_names:
+                for out_var_name in op.output(out_name):
+                    out_var = block.var(out_var_name)
+                    post_op = find_true_post_op(all_ops, op, out_var_name, True)
+
+                    if out_var is None or out_var.type not in _valid_types:
+                        change_op = False
+                        break
+                    op_post_ops.append(post_op)
+                    op_out_vars.append(out_var)
+
+            if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
+                for out_var in op_out_vars:
+                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+                    if to_bf16_var_names is not None and out_var.name in to_bf16_var_names:
+                        to_bf16_var_names.remove(out_var.name)
+                if op.has_attr('dtype') and op.attr(
+                        'dtype') == core.VarDesc.VarType.FP32:
+                    op._set_attr('dtype', core.VarDesc.VarType.BF16)
+
+
+def cast_model_to_bf16(program,
+                       startup_prog=None,
+                       amp_lists=None,
+                       use_bf16_guard=True):
     """
     Traverse all ops in the whole model and set their inputs and outputs
     to the bf16 data type. This function will do some special processing for
@@ -329,6 +374,10 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
             if op.has_attr('mkldnn_data_type'):
                 op._set_attr('mkldnn_data_type', 'bfloat16')
 
+    if startup_prog is not None:
+        cast_initializers_to_bf16(startup_prog, amp_lists, global_block,
+                                  ops, keep_fp32_ops, to_bf16_var_names)
+
     # process ops in keep_fp32_ops
     op_var_rename_map = [
         collections.OrderedDict() for _ in range(len(program.blocks))
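
The new startup_prog parameter is what wires the initializer conversion in. A rough usage sketch, mirroring test_graph_cast further down (the amp import path is an assumption):

    import paddle.fluid as fluid
    import paddle.fluid.contrib.mixed_precision as amp  # path assumed

    main_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()

    # Passing startup_prog makes cast_model_to_bf16 also run
    # cast_initializers_to_bf16 over the startup block, so fill_constant /
    # uniform_random outputs that feed pure-BF16 ops are created directly in BF16.
    amp.bf16.cast_model_to_bf16(
        main_prog,
        startup_prog,
        amp.bf16.AutoMixedPrecisionListsBF16(),
        use_bf16_guard=False)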

python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py

Lines changed: 7 additions & 4 deletions
@@ -94,7 +94,8 @@ def backward(self,
 
         if self._use_pure_bf16:
             self._to_bf16_var_names = cast_model_to_bf16(
-                self._train_program, self._amp_lists, self._use_bf16_guard)
+                self._train_program, startup_program, self._amp_lists,
+                self._use_bf16_guard)
         else:
             rewrite_program_bf16(self._train_program, self._amp_lists)
 
@@ -168,10 +169,12 @@ def run_example_code():
                                 self._to_bf16_var_names)
         if test_program is not None:
             if self._use_pure_bf16:
-                cast_model_to_bf16(test_program, self._amp_lists,
-                                   self._use_bf16_guard)
+                cast_model_to_bf16(
+                    test_program,
+                    amp_lists=self._amp_lists,
+                    use_bf16_guard=self._use_bf16_guard)
             elif use_bf16_test:
-                rewrite_program_bf16(test_program, self._amp_lists)
+                rewrite_program_bf16(test_program, amp_lists=self._amp_lists)
 
     def apply_gradients(self, params_grads):
         """

python/paddle/fluid/contrib/mixed_precision/fp16_utils.py

Lines changed: 21 additions & 9 deletions
@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         return num_cast_ops
 
     assert target_var.dtype == src_dtype, \
-        "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
+        "The real dtype({}) is not equal to the src dtype({})".format(
+            _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
 
     cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
     cast_var = block.vars.get(cast_name)
@@ -209,19 +210,30 @@ def find_true_prev_op(ops, cur_op, var_name):
     return None
 
 
-def find_true_post_op(ops, cur_op, var_name):
+def find_true_post_op(ops, cur_op, var_name, search_all=False):
     """
     if there are post ops, return them, if there is no post op,
     return None instead.
     Args:
         ops (list): A list of ops.
         cur_op (Operator): Current operator which has var_name variable.
         var_name (string): Variable name.
+        search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set.
     """
     post_op = []
-    for idx, op in enumerate(ops):
-        if op == cur_op:
-            break
+    if search_all:
+        """
+        \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come
+        from startup_prog block and \"ops\" list from main_prog block.
+        By setting idx to -1, we'll start looking for post-ops from the top of the list.
+        If search_all is False, assume that \"cur_op\" is in \"ops\" list,
+        so to reduce the time of search we can start iterating from \"cur_op\" idx.
+        """
+        idx = -1
+    else:
+        for idx, op in enumerate(ops):
+            if op == cur_op:
+                break
 
     for i in range(idx + 1, len(ops)):
         op = ops[i]
@@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
 
     if use_fp16_guard:
         if op.has_attr("op_namescope") and \
-            (_fp16_guard_pattern in op.attr("op_namescope")):
+                (_fp16_guard_pattern in op.attr("op_namescope")):
             # op in fp16 guard
             return False
         else:
@@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
     black_op_set = set()
     for op in ops:
 
-        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
-        # we don't need to handle reader op and the input of 'create_py_reader' is not
+        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
+        # we don't need to handle reader op and the input of 'create_py_reader' is not
         # in block, which may result in errors.
         # See GeneratorLoader._init_non_iterable() for details.
         if op.type == 'create_py_reader' or op.type == 'read':
@@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
                 raise ValueError("The cast op {0}'s output should not be"
                                  "used by a non-optimize op, however, it"
                                  "is used by {1}".format(op, post_ops[0]))
-            #add new op in the python and cpp at the same time
+            # add new op in the python and cpp at the same time
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(op.desc)
             new_op = framework.Operator(
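
The search_all flag exists for exactly the startup/main split used above: the initializer op lives in the startup block while ops comes from the main block. A short sketch of the two modes (placeholder names, mirroring the new unit test below):

    # init_op was prepended to the startup block, so main_block.ops does not contain it.
    # With search_all=False the function looks for init_op first and, not finding it,
    # ends up returning no post ops at all.
    assert find_true_post_op(main_block.ops, init_op, "X") == []
    # With search_all=True it scans main_block.ops from the top and returns the consumers of "X".
    post_ops = find_true_post_op(main_block.ops, init_op, "X", search_all=True)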

python/paddle/fluid/contrib/tests/test_bf16_utils.py

Lines changed: 23 additions & 0 deletions
@@ -139,6 +139,29 @@ def test_find_true_post_op(self):
         res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
         assert (res == [op2])
 
+    def test_find_true_post_op_with_search_all(self):
+        program = fluid.Program()
+        block = program.current_block()
+        startup_block = fluid.default_startup_program().global_block()
+
+        var1 = block.create_var(name="X", shape=[3], dtype='float32')
+        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
+        inititializer_op = startup_block._prepend_op(
+            type="fill_constant",
+            outputs={"Out": var1},
+            attrs={"shape": var1.shape,
+                   "dtype": var1.dtype,
+                   "value": 1.0})
+
+        op1 = block.append_op(
+            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=False)
+        assert (len(result) == 0)
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=True)
+        assert (result == [op1])
+
 
 if __name__ == '__main__':
     unittest.main()

python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py

Lines changed: 19 additions & 9 deletions
@@ -53,19 +53,27 @@ def scope_prog_guard(self):
         with fluid.program_guard(prog, startup_prog):
             yield
 
-    def get_static_graph_result(self, feed, fetch_list, amp_fun,
-                                with_lod=False):
+    def get_static_graph_result(self,
+                                feed,
+                                fetch_list,
+                                amp_fun,
+                                with_lod=False,
+                                startup_prog=None):
         exe = fluid.Executor(core.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(fluid.default_startup_program()
+                if startup_prog is None else startup_prog)
         prog = fluid.default_main_program()
         if amp_fun is not None:
-            amp_fun(prog)
+            if startup_prog is not None:
+                amp_fun(prog, startup_prog)
+            else:
+                amp_fun(prog)
         return exe.run(prog,
                        feed=feed,
                        fetch_list=fetch_list,
                        return_numpy=(not with_lod))
 
-    def _graph_common(self, _amp_fun):
+    def _graph_common(self, _amp_fun, startup_prog=None):
         size = 3
         n = np.ones([size, size], dtype='float32') * 3.2
         nn = np.ones([size, size], dtype='float32') * -2.7
@@ -122,7 +130,8 @@ def _graph_common(self, _amp_fun):
             self.get_static_graph_result(
                 feed={'t': n, 'tt': nn},
                 fetch_list=[ret],
-                amp_fun=_amp_fun
+                amp_fun=_amp_fun,
+                startup_prog=startup_prog
             )
         self.assertTrue(
             static_ret_bf16, np.ones(
@@ -132,16 +141,17 @@ def test_graph_rewrite(self):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
-                custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
+                custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))
 
     def test_graph_cast(self):
-        self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16(
+        self._graph_common(lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
             prog,
+            startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
-        ))
+        ), startup_prog=fluid.default_startup_program())
 
 
 if __name__ == '__main__':

python/paddle/fluid/layers/tensor.py

Lines changed: 5 additions & 5 deletions
@@ -231,13 +231,13 @@ def cast(x, dtype):
         out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
         return out
 
-    check_variable_and_dtype(
-        x, 'x',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
-        'cast')
+    check_variable_and_dtype(x, 'x', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
+        'uint16'
+    ], 'cast')
     check_dtype(dtype, 'dtype', [
         'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
-        'uint8'
+        'uint8', 'uint16'
     ], 'cast')
 
     helper = LayerHelper('cast', **locals())
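
Paddle stores bfloat16 variables under the 'uint16' dtype string, so adding it to both check lists lets cast() accept BF16 inputs and produce BF16 outputs in static graphs. A minimal sketch (static-graph mode assumed):

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    x = fluid.data(name='x', shape=[None, 3], dtype='float32')
    # 'uint16' is the dtype string Paddle uses for bfloat16 variables,
    # so this cast now passes check_dtype instead of raising.
    y = fluid.layers.cast(x, 'uint16')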

python/paddle/fluid/tests/book/test_fit_a_line.py

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
             amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)
-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(
+        avg_cost, startup_program=fluid.default_startup_program())
 
     BATCH_SIZE = 20

python/paddle/fluid/tests/book/test_word2vec_book.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __network__(words):
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)
 
-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(avg_cost, fluid.default_startup_program())
 
     train_reader = paddle.batch(
         paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
