Skip to content

Commit 149c7f1

Browse files
committed
fix static auc
1 parent 845920b commit 149c7f1

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tools/static_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def main(args):
126126

127127
epoch_begin = time.time()
128128
if use_auc:
129-
reset_auc(auc_num)
129+
reset_auc(use_fleet, auc_num)
130130
if reader_type == 'DataLoader':
131131
fetch_batch_var, step_num = dataloader_train(
132132
epoch_id, train_dataloader, input_data_names, fetch_vars, exe,
@@ -229,6 +229,7 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
229229
program=paddle.static.default_main_program(),
230230
feed=dict(zip(input_data_names, batch_data)),
231231
fetch_list=[var for _, var in fetch_vars.items()])
232+
# print(paddle.fluid.global_scope().find_var("_generated_var_2").get_tensor())
232233
train_run_cost += time.time() - train_start
233234
total_samples += batch_size
234235
if batch_id % print_interval == 0:

tools/utils/utils_single.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def load_yaml(yaml_file, other_part=None):
130130
return running_config
131131

132132

133-
def reset_auc(auc_num=1):
133+
def reset_auc(use_fleet=False, auc_num=1):
134134
# for static clear auc
135135
auc_var_name = []
136136
for i in range(auc_num * 4):
@@ -143,5 +143,9 @@ def reset_auc(auc_num=1):
143143
tensor = param.get_tensor()
144144
if param:
145145
tensor_array = np.zeros(tensor._get_dims()).astype("int64")
146-
tensor.set(tensor_array, paddle.CPUPlace())
146+
if use_fleet:
147+
trainer_id = paddle.distributed.get_rank()
148+
tensor.set(tensor_array, paddle.CUDAPlace(trainer_id))
149+
else:
150+
tensor.set(tensor_array, paddle.CPUPlace())
147151
logger.info("AUC Reset To Zero: {}".format(name))

0 commit comments

Comments
 (0)