2626import warnings
2727import logging
2828import ast
29+ import numpy as np
30+ import struct
2931
3032__dir__ = os .path .dirname (os .path .abspath (__file__ ))
3133sys .path .append (os .path .abspath (os .path .join (__dir__ , '..' )))
@@ -60,6 +62,10 @@ def parse_args():
6062 return config
6163
6264
65+ def bf16_to_fp32 (val ):
66+ return np .float32 (struct .unpack ('<f' , struct .pack ('<I' , val << 16 ))[0 ])
67+
68+
6369class Main (object ):
6470 def __init__ (self , config ):
6571 self .metrics = {}
@@ -86,7 +92,7 @@ def run(self):
8692 def network (self ):
8793 self .model = get_model (self .config )
8894 self .input_data = self .model .create_feeds ()
89- self .inference_feed_var = self .model .create_feeds (is_infer = True )
95+ self .inference_feed_var = self .model .create_feeds (is_infer = False )
9096 self .init_reader ()
9197 self .metrics = self .model .net (self .input_data )
9298 self .inference_target_var = self .model .inference_target_var
@@ -212,7 +218,6 @@ def dataloader_train_loop(self, epoch):
212218 batch_id = 0
213219 train_run_cost = 0.0
214220 total_examples = 0
215- from paddle .fluid .tests .unittests .op_test import convert_uint16_to_float
216221 self .reader .start ()
217222 while True :
218223 try :
@@ -232,7 +237,7 @@ def dataloader_train_loop(self, epoch):
232237 metrics_string += "{}: {}, " .format (
233238 var_name , fetch_var [var_idx ]
234239 if var_name != "LOSS" or not config ['pure_bf16' ]
235- else convert_uint16_to_float (fetch_var [var_idx ]))
240+ else bf16_to_fp32 (fetch_var [var_idx ][ 0 ]))
236241 profiler_string = ""
237242 profiler_string += "avg_batch_cost: {} sec, " .format (
238243 format ((train_run_cost ) / print_step , '.5f' ))
0 commit comments