Skip to content

Commit 9b9267d

Browse files
committed
global metrics
1 parent 964bda7 commit 9b9267d

File tree

1 file changed

+223
-14
lines changed

1 file changed

+223
-14
lines changed

tools/ps_online_trainer.py

Lines changed: 223 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from __future__ import print_function
16-
16+
import math
1717
import random
1818
import numpy as np
1919
from pathlib import Path
@@ -59,10 +59,11 @@ def __init__(self, config):
5959
self.save_delta_frequency = config.get("runner.save_delta_frequency",
6060
6)
6161
self.checkpoint_per_pass = config.get("runner.checkpoint_per_pass", 6)
62-
self.save_first_base = config.get("runner.save_first_base", True)
62+
self.save_first_base = config.get("runner.save_first_base", False)
6363
self.start_day = config.get("runner.start_day")
6464
self.end_day = config.get("runner.end_day")
6565
self.save_model_path = self.config.get("runner.model_save_path")
66+
self.data_path = self.config.get("runner.train_data_dir")
6667
if config.get("runner.fs_client.uri") is not None:
6768
self.hadoop_config = {}
6869
for key in ["uri", "user", "passwd", "hadoop_bin"]:
@@ -71,9 +72,8 @@ def __init__(self, config):
7172
self.hadoop_fs_name = self.hadoop_config.get("uri")
7273
self.hadoop_fs_ugi = self.hadoop_config.get(
7374
"user") + "," + self.hadoop_config.get("passwd")
74-
prefix = "hdfs:/user/paddle/" if self.hadoop_fs_name.startswith(
75-
"hdfs:") else "afs:/user/paddle/"
76-
self.save_model_path = prefix + self.save_model_path.strip("/")
75+
# prefix = "hdfs:/" if self.hadoop_fs_name.startswith("hdfs:") else "afs:/"
76+
# self.save_model_path = prefix + self.save_model_path.strip("/")
7777
else:
7878
self.hadoop_fs_name, self.hadoop_fs_ugi = None, None
7979
self.train_local = self.hadoop_fs_name is None or self.hadoop_fs_ugi is None
@@ -283,6 +283,13 @@ def get_last_save_xbox_base(self,
283283
xbox_base_key = int(last_dict["key"])
284284
return [last_day, last_path, xbox_base_key]
285285

286+
def clear_metrics(self, scope, var_list, var_types):
    """Reset the given metric variables in *scope* to zero.

    Zeroes each accumulator tensor (AUC stat buckets, sqrerr, abserr,
    prob, q, instance counters, ...) via FleetUtil.set_zero so the next
    pass starts with fresh statistics.

    Args:
        scope: the scope holding the metric variables.
        var_list: metric Variables to clear.
        var_types: tensor dtype string for each variable, aligned with
            var_list (e.g. "int64" for AUC buckets, "float32" for errors).
    """
    from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
    util = FleetUtil()
    for var, dtype in zip(var_list, var_types):
        util.set_zero(var.name, scope, param_type=dtype)
292+
286293
def get_global_auc(self,
287294
scope=fluid.global_scope(),
288295
stat_pos="_generated_var_2",
@@ -315,6 +322,7 @@ def get_global_auc(self,
315322
global_pos = fleet.util.all_reduce(pos)
316323
# reshape to its original shape
317324
global_pos = global_pos.reshape(old_pos_shape)
325+
print('debug global auc global_pos: ', global_pos)
318326

319327
# auc neg bucket
320328
neg = np.array(scope.find_var(stat_neg).get_tensor())
@@ -323,6 +331,7 @@ def get_global_auc(self,
323331
#global_neg = np.copy(neg) * 0
324332
global_neg = fleet.util.all_reduce(neg)
325333
global_neg = global_neg.reshape(old_neg_shape)
334+
print('debug global auc global_neg: ', global_neg)
326335

327336
# calculate auc
328337
num_bucket = len(global_pos[0])
@@ -615,7 +624,7 @@ def write_xbox_donefile(self,
615624
else:
616625
with open(donefile_name, "w") as f:
617626
f.write(xbox_str + "\n")
618-
client.upad(
627+
client.upload(
619628
donefile_name,
620629
output_path,
621630
multi_processes=1,
@@ -667,6 +676,166 @@ def _get_xbox_str(self,
667676
) + model_path.rstrip("/") + "/000"
668677
return json.dumps(xbox_dict)
669678

679+
def get_global_metrics(self,
                       scope=fluid.global_scope(),
                       stat_pos_name="_generated_var_2",
                       stat_neg_name="_generated_var_3",
                       sqrerr_name="sqrerr",
                       abserr_name="abserr",
                       prob_name="prob",
                       q_name="q",
                       pos_ins_num_name="pos",
                       total_ins_num_name="total"):
    """All-reduce per-worker metric tensors and derive global metrics.

    Aggregates the AUC stat buckets plus the scalar metric accumulators
    across all trainers (fleet.util.all_reduce), then computes the usual
    CTR-model quality numbers, including the Baidu-style bucket error.

    Args:
        scope: scope holding the metric variables (default global scope;
            note this default is evaluated once at definition time).
        stat_pos_name / stat_neg_name: AUC positive/negative bucket vars.
        sqrerr_name / abserr_name: accumulated squared/absolute error.
        prob_name: accumulated predicted probability.
        q_name: accumulated q value.
        pos_ins_num_name / total_ins_num_name: instance counters.

    Returns:
        [auc, bucket_error, mae, rmse, actual_ctr, predicted_ctr,
         copc, mean_predict_qvalue, total_ins_num], or [None] * 9 when
        any required variable is missing or no instances were seen.
    """
    from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
    fleet_util = FleetUtil()
    # Bail out early (with a rank-0 log line) if any metric var is absent.
    if scope.find_var(stat_pos_name) is None or \
            scope.find_var(stat_neg_name) is None:
        fleet_util.rank0_print("not found auc bucket")
        return [None] * 9
    elif scope.find_var(sqrerr_name) is None:
        fleet_util.rank0_print("not found sqrerr_name=%s" % sqrerr_name)
        return [None] * 9
    elif scope.find_var(abserr_name) is None:
        fleet_util.rank0_print("not found abserr_name=%s" % abserr_name)
        return [None] * 9
    elif scope.find_var(prob_name) is None:
        fleet_util.rank0_print("not found prob_name=%s" % prob_name)
        return [None] * 9
    elif scope.find_var(q_name) is None:
        fleet_util.rank0_print("not found q_name=%s" % q_name)
        return [None] * 9
    elif scope.find_var(pos_ins_num_name) is None:
        fleet_util.rank0_print("not found pos_ins_num_name=%s" %
                               pos_ins_num_name)
        return [None] * 9
    elif scope.find_var(total_ins_num_name) is None:
        fleet_util.rank0_print("not found total_ins_num_name=%s" %
                               total_ins_num_name)
        return [None] * 9

    # barrier worker to ensure all workers finished training
    fleet.barrier_worker()

    # Global AUC from the stat buckets.
    auc = self.get_global_auc(scope, stat_pos_name, stat_neg_name)

    # AUC positive bucket: flatten, all-reduce, restore shape.
    pos = np.array(scope.find_var(stat_pos_name).get_tensor())
    old_pos_shape = np.array(pos.shape)
    pos = pos.reshape(-1)
    global_pos = fleet.util.all_reduce(pos)
    global_pos = global_pos.reshape(old_pos_shape)

    # AUC negative bucket: same treatment.
    neg = np.array(scope.find_var(stat_neg_name).get_tensor())
    old_neg_shape = np.array(neg.shape)
    neg = neg.reshape(-1)
    global_neg = fleet.util.all_reduce(neg)
    global_neg = global_neg.reshape(old_neg_shape)

    num_bucket = len(global_pos[0])

    def get_metric(name):
        # All-reduce one scalar metric tensor and return its first element.
        metric = np.array(scope.find_var(name).get_tensor())
        old_metric_shape = np.array(metric.shape)
        metric = metric.reshape(-1)
        print(name, 'ori value:', metric)
        global_metric = fleet.util.all_reduce(metric)
        global_metric = global_metric.reshape(old_metric_shape)
        print(name, global_metric)
        return global_metric[0]

    global_sqrerr = get_metric(sqrerr_name)
    global_abserr = get_metric(abserr_name)
    global_prob = get_metric(prob_name)
    global_q_value = get_metric(q_name)
    # note: get ins_num from auc bucket is not actual value,
    # so get it from metric op
    pos_ins_num = get_metric(pos_ins_num_name)
    total_ins_num = get_metric(total_ins_num_name)

    if total_ins_num == 0:
        # Avoid ZeroDivisionError when this pass saw no instances.
        fleet_util.rank0_print("total_ins_num is 0; skip global metrics")
        return [None] * 9

    mae = global_abserr / total_ins_num
    rmse = math.sqrt(global_sqrerr / total_ins_num)
    return_actual_ctr = pos_ins_num / total_ins_num
    predicted_ctr = global_prob / total_ins_num
    mean_predict_qvalue = global_q_value / total_ins_num
    copc = 0.0
    # BUGFIX: was `abs(predicted_ctr > 1e-6)`, which takes abs() of a
    # bool; compare the magnitude of predicted_ctr itself (matches
    # upstream FleetUtil.get_global_metrics).
    if abs(predicted_ctr) > 1e-6:
        copc = return_actual_ctr / predicted_ctr

    # Calculate bucket error: merge adjacent buckets until the relative
    # error of the merged span is statistically reliable, then compare
    # actual vs adjusted CTR over that span.
    last_ctr = -1.0
    impression_sum = 0.0
    ctr_sum = 0.0
    click_sum = 0.0
    error_sum = 0.0
    error_count = 0.0
    k_max_span = 0.01
    k_relative_error_bound = 0.05
    for i in range(num_bucket):
        click = global_pos[0][i]
        show = global_pos[0][i] + global_neg[0][i]
        ctr = float(i) / num_bucket
        if abs(ctr - last_ctr) > k_max_span:
            # Start a new merge span.
            last_ctr = ctr
            impression_sum = 0.0
            ctr_sum = 0.0
            click_sum = 0.0
        impression_sum += show
        ctr_sum += ctr * show
        click_sum += click
        if impression_sum == 0:
            continue
        adjust_ctr = ctr_sum / impression_sum
        if adjust_ctr == 0:
            continue
        relative_error = \
            math.sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum))
        if relative_error < k_relative_error_bound:
            actual_ctr = click_sum / impression_sum
            relative_ctr_error = abs(actual_ctr / adjust_ctr - 1)
            error_sum += relative_ctr_error * impression_sum
            error_count += impression_sum
            # Force the next bucket to open a fresh span.
            last_ctr = -1

    bucket_error = error_sum / error_count if error_count > 0 else 0.0

    return [
        auc, bucket_error, mae, rmse, return_actual_ctr, predicted_ctr,
        copc, mean_predict_qvalue, int(total_ins_num)
    ]
822+
823+
def get_global_metrics_str(self, scope, metric_list, prefix):
824+
if len(metric_list) != 10:
825+
raise ValueError("len(metric_list) != 10, %s" % len(metric_list))
826+
827+
auc, bucket_error, mae, rmse, actual_ctr, predicted_ctr, copc, \
828+
mean_predict_qvalue, total_ins_num = self.get_global_metrics( \
829+
scope, metric_list[2].name, metric_list[3].name, metric_list[4].name, metric_list[5].name, \
830+
metric_list[6].name, metric_list[7].name, metric_list[8].name, metric_list[9].name)
831+
metrics_str = "%s global AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f " \
832+
"RMSE=%.6f Actural_CTR=%.6f Predicted_CTR=%.6f " \
833+
"COPC=%.6f MEAN Q_VALUE=%.6f Ins number=%s" % \
834+
(prefix, auc, bucket_error, mae, rmse, \
835+
actual_ctr, predicted_ctr, copc, mean_predict_qvalue, \
836+
total_ins_num)
837+
return metrics_str
838+
670839
def init_network(self):
671840
self.model = get_model(self.config)
672841
self.input_data = self.model.create_feeds()
@@ -676,6 +845,15 @@ def init_network(self):
676845
self.predict = self.model.predict
677846
self.inference_model_feed_vars = self.model.inference_model_feed_vars
678847
logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
848+
thread_stat_var_names = [
849+
self.model.auc_stat_list[2].name, self.model.auc_stat_list[3].name
850+
]
851+
852+
thread_stat_var_names += [i.name for i in self.model.metric_list]
853+
854+
thread_stat_var_names = list(set(thread_stat_var_names))
855+
856+
self.config['stat_var_names'] = thread_stat_var_names
679857
self.model.create_optimizer(get_strategy(self.config))
680858

681859
def run_server(self):
@@ -697,17 +875,17 @@ def file_ls(self, path_array):
697875
"fs.default.name": self.hadoop_fs_name,
698876
"hadoop.job.ugi": self.hadoop_fs_ugi
699877
}
878+
data_path = self.data_path
700879
hdfs_client = HDFSClient("$HADOOP_HOME", configs)
701880
for i in path_array:
702881
cur_path = hdfs_client.ls_dir(i)[1]
703-
prefix = "hdfs:/user/paddle/" if self.hadoop_fs_name.startswith(
704-
"hdfs:") else "afs:/user/paddle/"
882+
#prefix = "hdfs:" if self.hadoop_fs_name.startswith("hdfs:") else "afs:"
705883
if len(cur_path) > 0:
706884
i = i.strip("/")
707-
result += [
708-
prefix + i.rstrip("/") + "/" + j for j in cur_path
709-
]
710-
result += cur_path
885+
#result += [prefix + i.rstrip("/") + "/" + j for j in cur_path]
886+
result += [i.rstrip("/") + "/" + j for j in cur_path]
887+
# #result += cur_path
888+
# result += [data_path.rstrip("/") + "/" + j for j in cur_path]
711889
logger.info("file ls result = {}".format(result))
712890
return result
713891

@@ -736,7 +914,7 @@ def prepare_dataset(self, day, pass_index):
736914
# dataset, file_list = get_reader(self.input_data, config)
737915

738916
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
739-
#dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
917+
# dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
740918
dataset.set_use_var(self.input_data)
741919
dataset.set_batch_size(self.config.get('runner.train_batch_size'))
742920
dataset.set_thread(self.config.get('runner.train_thread_num', 1))
@@ -753,7 +931,7 @@ def prepare_dataset(self, day, pass_index):
753931
logger.info("my_file_list = {}".format(my_file_list))
754932
dataset.set_filelist(my_file_list)
755933
pipe_command = self.config.get("runner.pipe_command")
756-
#dataset.set_pipe_command(self.config.get("runner.pipe_command"))
934+
# dataset.set_pipe_command(self.config.get("runner.pipe_command"))
757935
utils_path = common.get_utils_file_path()
758936
dataset.set_pipe_command("{} {} {}".format(
759937
pipe_command, config.get("yaml_path"), utils_path))
@@ -839,6 +1017,22 @@ def run_worker(self):
8391017
"Prepare Dataset Done, using time {} second.".format(
8401018
prepare_data_end_time - prepare_data_start_time))
8411019

1020+
set_dump_config(paddle.static.default_main_program(), {
1021+
"dump_fields_path": './test_dump',
1022+
"dump_fields": [
1023+
'sparse_embedding_0.tmp_0@GRAD',
1024+
'sequence_pool_0.tmp_0@GRAD', 'concat_0.tmp_0@GRAD',
1025+
'concat_0.tmp_0', 'linear_6.tmp_1', 'relu_0.tmp_0',
1026+
'linear_7.tmp_1', 'relu_1.tmp_0', 'linear_8.tmp_1',
1027+
'relu_2.tmp_0', 'linear_9.tmp_1', 'relu_3.tmp_0',
1028+
'linear_10.tmp_1', 'relu_4.tmp_0', 'linear_11.tmp_1',
1029+
'sigmoid_0.tmp_0', 'clip_0.tmp_0',
1030+
'sigmoid_0.tmp_0@GRAD', 'clip_0.tmp_0@GRAD',
1031+
'linear_11.tmp_1@GRAD', 'linear_9.tmp_1@GRAD',
1032+
'linear_6.tmp_1@GRAD', 'concat_0.tmp_0@GRAD'
1033+
],
1034+
})
1035+
8421036
train_start_time = time.time()
8431037
train_end_time = time.time()
8441038
logger.info("Train Dataset Done, using time {} second.".format(
@@ -861,6 +1055,21 @@ def run_worker(self):
8611055
dataset.release_memory()
8621056
global_auc = self.get_global_auc()
8631057
logger.info(" global auc %f" % global_auc)
1058+
1059+
metric_list = list(self.model.auc_stat_list) + list(
1060+
self.model.metric_list)
1061+
1062+
metric_types = ["int64"] * len(self.model.auc_stat_list) + [
1063+
"float32"
1064+
] * len(self.model.metric_list)
1065+
1066+
metric_str = self.get_global_metrics_str(
1067+
fluid.global_scope(), metric_list, "update pass:")
1068+
1069+
logger.info(" global metric %s" % metric_str)
1070+
1071+
self.clear_metrics(fluid.global_scope(), metric_list,
1072+
metric_types)
8641073
if fleet.is_first_worker():
8651074
if index % self.checkpoint_per_pass == 0:
8661075
self.save_model(save_model_path, day, index)

0 commit comments

Comments
 (0)