Skip to content

Commit d2c2285

Browse files
Support stride dy2st (#6712)
1 parent 09329bd commit d2c2285

File tree

1 file changed

+49
-21
lines changed

1 file changed

+49
-21
lines changed

model_zoo/gpt-3/ppfleetx/models/language_model/gpt/dygraph/single_model.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,28 +1079,56 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
10791079
# make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
10801080
model_kwargs["attention_mask"] = paddle.reshape(attn_mask, paddle.shape(attn_mask))
10811081
model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
1082-
while cur_len < max_length:
1083-
# Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
1084-
# and change it to pass directly to _post_process_ to avoid
1085-
# closed-loop problem of dynamic-to-static model
1086-
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
1087-
_forward_(**model_kwargs),
1088-
input_ids,
1089-
cur_len_gpu,
1090-
origin_len_gpu,
1091-
scores,
1092-
unfinished_flag,
1093-
model_kwargs,
1094-
)
1095-
if not self.inference:
1096-
cur_len += 1
1097-
else:
1098-
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
1099-
paddle.increment(cur_len)
1100-
paddle.increment(cur_len_gpu)
1082+
if hasattr(paddle.framework, "_no_check_dy2st_diff"):
1083+
# TODO(wanghuancoder): _no_check_dy2st_diff is used to turn off the check for behavior
1084+
# inconsistency between the dynamic graph and the static graph. _no_check_dy2st_diff should be
1085+
# removed once static graphs support inplace and stride.
1086+
with paddle.framework._no_check_dy2st_diff():
1087+
while cur_len < max_length:
1088+
# Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
1089+
# and change it to pass directly to _post_process_ to avoid
1090+
# closed-loop problem of dynamic-to-static model
1091+
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
1092+
_forward_(**model_kwargs),
1093+
input_ids,
1094+
cur_len_gpu,
1095+
origin_len_gpu,
1096+
scores,
1097+
unfinished_flag,
1098+
model_kwargs,
1099+
)
1100+
if not self.inference:
1101+
cur_len += 1
1102+
else:
1103+
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
1104+
paddle.increment(cur_len)
1105+
paddle.increment(cur_len_gpu)
1106+
1107+
if not paddle.any(unfinished_flag):
1108+
break
1109+
else:
1110+
while cur_len < max_length:
1111+
# Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
1112+
# and change it to pass directly to _post_process_ to avoid
1113+
# closed-loop problem of dynamic-to-static model
1114+
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
1115+
_forward_(**model_kwargs),
1116+
input_ids,
1117+
cur_len_gpu,
1118+
origin_len_gpu,
1119+
scores,
1120+
unfinished_flag,
1121+
model_kwargs,
1122+
)
1123+
if not self.inference:
1124+
cur_len += 1
1125+
else:
1126+
# Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
1127+
paddle.increment(cur_len)
1128+
paddle.increment(cur_len_gpu)
11011129

1102-
if not paddle.any(unfinished_flag):
1103-
break
1130+
if not paddle.any(unfinished_flag):
1131+
break
11041132

11051133
return model_kwargs["res"][:, origin_len:], scores
11061134

0 commit comments

Comments
 (0)