Skip to content

Commit e74d2a4

Browse files
committed
Cast weights from fp16 to fp32 after loading.
1 parent 969939e commit e74d2a4

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

PaddleNLP/benchmark/transformer/static/predict.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24+
def cast_parameters_to_fp32(place, program, scope=None):
    """Upcast parameters that were loaded as fp16 back to fp32, in place.

    Every parameter of every block in ``program`` is looked up by name in
    the variable scope. When the tensor held in the scope reports an fp16
    dtype while the parameter itself is declared as fp32, the tensor data
    is converted to ``numpy.float32`` and written back to the same tensor.

    Args:
        place: Device placement forwarded unchanged to ``tensor.set``.
        program: Object exposing ``blocks``; each block must provide
            ``all_parameters()``.
        scope: Variable scope used to resolve parameter names. When
            ``None``, ``paddle.static.global_scope()`` is used.

    Returns:
        None. Matching tensors are modified in place.
    """
    # Compare against None explicitly so the fallback does not depend on
    # the truthiness of whatever scope object the caller passes.
    var_scope = paddle.static.global_scope() if scope is None else scope

    for block in program.blocks:
        for param in block.all_parameters():
            tensor = var_scope.find_var(param.name).get_tensor()
            # Dtypes are matched on their string form, as in the original
            # check (case-insensitive 'fp16' / 'fp32' substrings).
            stored_as_fp16 = 'fp16' in str(tensor._dtype()).lower()
            declared_as_fp32 = 'fp32' in str(param.dtype).lower()
            if stored_as_fp16 and declared_as_fp32:
                # astype() always yields an ndarray; np.float32(arr) would
                # collapse a 0-d array into a NumPy scalar.
                data = np.array(tensor).astype(np.float32)
                tensor.set(data, place)
36+
37+
2438
def parse_args():
2539
parser = argparse.ArgumentParser()
2640
parser.add_argument(
@@ -56,7 +70,7 @@ def do_train(args):
5670
place = paddle.set_device("cpu")
5771

5872
# Define data loader
59-
test_loader, to_tokens = reader.create_infer_loader(args)
73+
test_loader, to_tokens = reader.create_infer_loader(args, use_all_vocab=True)
6074

6175
test_program = paddle.static.Program()
6276
startup_program = paddle.static.Program()
@@ -93,6 +107,9 @@ def do_train(args):
93107
os.path.join(args.init_from_params, "transformer"), exe)
94108
print("finish initing model from params from %s" % (args.init_from_params))
95109

110+
# cast weights from fp16 to fp32 after loading
111+
cast_parameters_to_fp32(place, test_program)
112+
96113
f = open(args.output_file, "w")
97114
for data in test_loader:
98115
finished_sequence, = exe.run(test_program,

0 commit comments

Comments
 (0)