@@ -1079,28 +1079,56 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
         model_kwargs["attention_mask"] = paddle.reshape(attn_mask, paddle.shape(attn_mask))
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
-        while cur_len < max_length:
-            # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
-            # and change it to pass directly to _post_process_ to avoid
-            # closed-loop problem of dynamic-to-static model
-            input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
-                _forward_(**model_kwargs),
-                input_ids,
-                cur_len_gpu,
-                origin_len_gpu,
-                scores,
-                unfinished_flag,
-                model_kwargs,
-            )
-            if not self.inference:
-                cur_len += 1
-            else:
-                # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
-                paddle.increment(cur_len)
-                paddle.increment(cur_len_gpu)
+        if hasattr(paddle.framework, "_no_check_dy2st_diff"):
+            # TODO(wanghuancoder): _no_check_dy2st_diff is used to turn off the checking of behavior
+            # inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
+            # removed after static graphs support inplace and stride.
+            with paddle.framework._no_check_dy2st_diff():
+                while cur_len < max_length:
+                    # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
+                    # and change it to pass directly to _post_process_ to avoid
+                    # closed-loop problem of dynamic-to-static model
+                    input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+                        _forward_(**model_kwargs),
+                        input_ids,
+                        cur_len_gpu,
+                        origin_len_gpu,
+                        scores,
+                        unfinished_flag,
+                        model_kwargs,
+                    )
+                    if not self.inference:
+                        cur_len += 1
+                    else:
+                        # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
+                        paddle.increment(cur_len)
+                        paddle.increment(cur_len_gpu)
+
+                    if not paddle.any(unfinished_flag):
+                        break
+        else:
+            while cur_len < max_length:
+                # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
+                # and change it to pass directly to _post_process_ to avoid
+                # closed-loop problem of dynamic-to-static model
+                input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+                    _forward_(**model_kwargs),
+                    input_ids,
+                    cur_len_gpu,
+                    origin_len_gpu,
+                    scores,
+                    unfinished_flag,
+                    model_kwargs,
+                )
+                if not self.inference:
+                    cur_len += 1
+                else:
+                    # Note(ZhenyuLi): Avoid the synchronization caused by scale in dy2static
+                    paddle.increment(cur_len)
+                    paddle.increment(cur_len_gpu)
 
-            if not paddle.any(unfinished_flag):
-                break
+                if not paddle.any(unfinished_flag):
+                    break
 
         return model_kwargs["res"][:, origin_len:], scores
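
A quick sketch of the trick the `attention_mask` context line relies on (my reading of the comment, with a hypothetical `keep_dynamic` wrapper, not code from this PR): under `paddle.jit.to_static`, reshaping a tensor by its own runtime `paddle.shape(...)` tensor keeps every dimension dynamic (`-1`) in the converted graph, instead of baking in the tracing-time batch and sequence sizes.

```python
import paddle


@paddle.jit.to_static
def keep_dynamic(attn_mask):
    # paddle.shape() returns a runtime shape *tensor*, so dy2static records
    # the reshape target as (-1, -1, -1, -1) rather than concrete integers.
    return paddle.reshape(attn_mask, paddle.shape(attn_mask))
```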
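On older Paddle builds that predate `paddle.framework._no_check_dy2st_diff`, the `else` branch runs a second verbatim copy of the decoding loop. A minimal sketch of the same feature-detection pattern without the duplication, assuming `contextlib.nullcontext` as the no-op fallback (`_dy2st_diff_guard`, `decode_step`, and `generate_loop` are hypothetical names, not part of this PR):

```python
import contextlib

import paddle


def _dy2st_diff_guard():
    # Probe for the private API with hasattr(), exactly as the diff does,
    # and fall back to a no-op context manager when this Paddle lacks it.
    if hasattr(paddle.framework, "_no_check_dy2st_diff"):
        return paddle.framework._no_check_dy2st_diff()
    return contextlib.nullcontext()


def generate_loop(decode_step, cur_len, max_length):
    # One decoding loop instead of two copies; decode_step stands in for
    # the _forward_/_post_process_ pipeline above.
    with _dy2st_diff_guard():
        while cur_len < max_length:
            cur_len, finished = decode_step(cur_len)
            if finished:
                break
    return cur_len
```

Whether the duplicated-loop form was chosen for dy2static tracing reasons isn't stated in the PR, so treat this as a sketch rather than a drop-in refactor.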
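On the Note(ZhenyuLi) comments: as I read them, `cur_len += 1` lowers to a scale op whose scalar result the dy2static executor may synchronize on, whereas `paddle.increment` bumps the one-element tensor in place on device. A tiny illustration (the counter tensor here is hypothetical):

```python
import paddle

cur_len = paddle.to_tensor([0], dtype="int64")  # 1-element counter, as in the loop
paddle.increment(cur_len)  # in-place add of 1 -> tensor now holds [1]
print(cur_len.item())      # 1
```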