Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion PaddleNLP/benchmark/transformer/static/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@
logger = logging.getLogger(__name__)


def cast_parameters_to_fp32(place, program, scope=None):
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())

var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
tensor = var_scope.find_var(param.name).get_tensor()
if 'fp16' in str(tensor._dtype()).lower() and \
'fp32' in str(param.dtype).lower():
data = np.array(tensor)
tensor.set(np.float32(data), place)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -56,7 +70,7 @@ def do_train(args):
place = paddle.set_device("cpu")

# Define data loader
test_loader, to_tokens = reader.create_infer_loader(args)
test_loader, to_tokens = reader.create_infer_loader(args, use_all_vocab=True)

test_program = paddle.static.Program()
startup_program = paddle.static.Program()
Expand Down Expand Up @@ -93,6 +107,9 @@ def do_train(args):
os.path.join(args.init_from_params, "transformer"), exe)
print("finish initing model from params from %s" % (args.init_from_params))

# cast weights from fp16 to fp32 after loading
cast_parameters_to_fp32(place, test_program)

f = open(args.output_file, "w")
for data in test_loader:
finished_sequence, = exe.run(test_program,
Expand Down