Skip to content

Commit 6467e0c

Browse files
committed
fix metric_helper bug
1 parent 7139b6c commit 6467e0c

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

tools/utils/static_ps/metric_helper.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,30 +109,27 @@ def get_global_metrics(scope=fluid.global_scope(),
109109
q_name="q",
110110
pos_ins_num_name="pos",
111111
total_ins_num_name="total"):
112-
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
113-
fleet_util = FleetUtil()
114112
if scope.find_var(stat_pos_name) is None or \
115113
scope.find_var(stat_neg_name) is None:
116-
fleet_util.rank0_print("not found auc bucket")
114+
logger.info("not found auc bucket")
117115
return [None] * 9
118116
elif scope.find_var(sqrerr_name) is None:
119-
fleet_util.rank0_print("not found sqrerr_name=%s" % sqrerr_name)
117+
logger.info("not found sqrerr_name=%s" % sqrerr_name)
120118
return [None] * 9
121119
elif scope.find_var(abserr_name) is None:
122-
fleet_util.rank0_print("not found abserr_name=%s" % abserr_name)
120+
logger.info("not found abserr_name=%s" % abserr_name)
123121
return [None] * 9
124122
elif scope.find_var(prob_name) is None:
125-
fleet_util.rank0_print("not found prob_name=%s" % prob_name)
123+
logger.info("not found prob_name=%s" % prob_name)
126124
return [None] * 9
127125
elif scope.find_var(q_name) is None:
128-
fleet_util.rank0_print("not found q_name=%s" % q_name)
126+
logger.info("not found q_name=%s" % q_name)
129127
return [None] * 9
130128
elif scope.find_var(pos_ins_num_name) is None:
131-
fleet_util.rank0_print("not found pos_ins_num_name=%s" %
132-
pos_ins_num_name)
129+
logger.info("not found pos_ins_num_name=%s" % pos_ins_num_name)
133130
return [None] * 9
134131
elif scope.find_var(total_ins_num_name) is None:
135-
fleet_util.rank0_print("not found total_ins_num_name=%s" % \
132+
logger.info("not found total_ins_num_name=%s" % \
136133
total_ins_num_name)
137134
return [None] * 9
138135

@@ -261,8 +258,28 @@ def get_global_metrics_str(scope, metric_list, prefix):
261258
return metrics_str
262259

263260

261+
def set_zero(var_name,
262+
scope=fluid.global_scope(),
263+
place=fluid.CPUPlace(),
264+
param_type="int64"):
265+
"""
266+
Set tensor of a Variable to zero.
267+
268+
Args:
269+
var_name(str): name of Variable
270+
scope(Scope): Scope object, default is fluid.global_scope()
271+
place(Place): Place object, default is fluid.CPUPlace()
272+
param_type(str): param data type, default is int64
273+
274+
Examples:
275+
set_zero(myvar.name, myscope)
276+
277+
"""
278+
param = scope.var(var_name).get_tensor()
279+
param_array = np.zeros(param._get_dims()).astype(param_type)
280+
param.set(param_array, place)
281+
282+
264283
def clear_metrics(scope, var_list, var_types):
265-
from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil
266-
fleet_util = FleetUtil()
267284
for i in range(len(var_list)):
268-
fleet_util.set_zero(var_list[i].name, scope, param_type=var_types[i])
285+
set_zero(var_list[i].name, scope, param_type=var_types[i])

0 commit comments

Comments
 (0)