Skip to content

Commit 3481402

Browse files
committed
Cast weights from fp16 to fp32 after loading.
1 parent 969939e commit 3481402

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

PaddleNLP/benchmark/transformer/static/predict.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24+
def cast_parameters_to_fp32(place, program, scope=None):
25+
all_parameters = []
26+
for block in program.blocks:
27+
all_parameters.extend(block.all_parameters())
28+
29+
var_scope = scope if scope else paddle.static.global_scope()
30+
for param in all_parameters:
31+
tensor = var_scope.find_var(param.name).get_tensor()
32+
if 'FP16' in str(tensor._dtype()) and 'FP32' in str(param.dtype):
33+
data = np.array(tensor)
34+
tensor.set(np.float32(data), place)
35+
36+
2437
def parse_args():
2538
parser = argparse.ArgumentParser()
2639
parser.add_argument(
@@ -56,7 +69,7 @@ def do_train(args):
5669
place = paddle.set_device("cpu")
5770

5871
# Define data loader
59-
test_loader, to_tokens = reader.create_infer_loader(args)
72+
test_loader, to_tokens = reader.create_infer_loader(args, use_all_vocab=True)
6073

6174
test_program = paddle.static.Program()
6275
startup_program = paddle.static.Program()
@@ -93,6 +106,9 @@ def do_train(args):
93106
os.path.join(args.init_from_params, "transformer"), exe)
94107
print("finish initing model from params from %s" % (args.init_from_params))
95108

109+
# cast weights from fp16 to fp32 after loading
110+
cast_parameters_to_fp32(place, test_program)
111+
96112
f = open(args.output_file, "w")
97113
for data in test_loader:
98114
finished_sequence, = exe.run(test_program,

0 commit comments

Comments
 (0)