Skip to content

Commit 98e523f

Browse files
yinhaofengroot
authored andcommitted
benchmark train
1 parent 354b077 commit 98e523f

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

models/recall/word2vec/benchmark/static_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def net(self, inputs, is_infer=False):
156156

157157
return self.metrics
158158

159-
def create_optimizer(self, strategy=None, pure_bf16=False):
159+
def create_optimizer(self, strategy=None):
160+
pure_bf16 = self.config.get("pure_bf16")
160161
lr = float(self.config.get("hyper_parameters.optimizer.learning_rate"))
161162
decay_rate = float(
162163
self.config.get("hyper_parameters.optimizer.decay_rate"))

tools/static_ps_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import warnings
2727
import logging
2828
import ast
29+
import numpy as np
30+
import struct
2931

3032
__dir__ = os.path.dirname(os.path.abspath(__file__))
3133
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
@@ -60,6 +62,10 @@ def parse_args():
6062
return config
6163

6264

65+
def bf16_to_fp32(val):
66+
return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
67+
68+
6369
class Main(object):
6470
def __init__(self, config):
6571
self.metrics = {}
@@ -86,7 +92,7 @@ def run(self):
8692
def network(self):
8793
self.model = get_model(self.config)
8894
self.input_data = self.model.create_feeds()
89-
self.inference_feed_var = self.model.create_feeds(is_infer=True)
95+
self.inference_feed_var = self.model.create_feeds(is_infer=False)
9096
self.init_reader()
9197
self.metrics = self.model.net(self.input_data)
9298
self.inference_target_var = self.model.inference_target_var
@@ -212,7 +218,6 @@ def dataloader_train_loop(self, epoch):
212218
batch_id = 0
213219
train_run_cost = 0.0
214220
total_examples = 0
215-
from paddle.fluid.tests.unittests.op_test import convert_uint16_to_float
216221
self.reader.start()
217222
while True:
218223
try:
@@ -232,7 +237,7 @@ def dataloader_train_loop(self, epoch):
232237
metrics_string += "{}: {}, ".format(
233238
var_name, fetch_var[var_idx]
234239
if var_name != "LOSS" or not config['pure_bf16']
235-
else convert_uint16_to_float(fetch_var[var_idx]))
240+
else bf16_to_fp32(fetch_var[var_idx][0]))
236241
profiler_string = ""
237242
profiler_string += "avg_batch_cost: {} sec, ".format(
238243
format((train_run_cost) / print_step, '.5f'))

0 commit comments

Comments
 (0)