Skip to content

Commit 0ba1205

Browse files
authored
Merge pull request #1414 from wangxicoding/fix_static_bert_amp_nan
Fix NaN in static BERT training under AMP (automatic mixed precision)
2 parents f3010d9 + 97e51ff commit 0ba1205

File tree

2 files changed: 2 additions and 2 deletions.

examples/language_model/bert/static/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __getitem__(self, index):
125 125   # TODO: whether to use reversed mask by changing 1s and 0s to be
126 126   # consistent with nv bert
127 127   input_mask = (1 - np.reshape(
128     - input_mask.astype(np.float32), [1, 1, input_mask.shape[0]])) * -1e9
    128 + input_mask.astype(np.float32), [1, 1, input_mask.shape[0]])) * -1e4
129 129
130 130   index = self.max_pred_length
131 131   # store number of masked tokens in index

examples/language_model/bert/static/run_pretrain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def parse_args():
136 136   parser.add_argument(
137 137   "--scale_loss",
138 138   type=float,
139     - default=1.0,
    139 + default=2**15,
140 140   help="The value of scale_loss for fp16.")
141 141   parser.add_argument(
142 142   "--use_pure_fp16",

0 commit comments

Comments
 (0)