Skip to content

Commit a167b70

Browse files
authored: Merge pull request #439 from yinhaofeng/collective_train_2
fix static auc
2 parents e0a7ba1 + 149c7f1 commit a167b70

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tools/static_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def main(args):
131131

132132
epoch_begin = time.time()
133133
if use_auc:
134-
reset_auc(auc_num)
134+
reset_auc(use_fleet, auc_num)
135135
if reader_type == 'DataLoader':
136136
fetch_batch_var, step_num = dataloader_train(
137137
epoch_id, train_dataloader, input_data_names, fetch_vars, exe,
@@ -234,6 +234,7 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
234234
program=paddle.static.default_main_program(),
235235
feed=dict(zip(input_data_names, batch_data)),
236236
fetch_list=[var for _, var in fetch_vars.items()])
237+
# print(paddle.fluid.global_scope().find_var("_generated_var_2").get_tensor())
237238
train_run_cost += time.time() - train_start
238239
total_samples += batch_size
239240
if batch_id % print_interval == 0:

tools/utils/utils_single.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def load_yaml(yaml_file, other_part=None):
130130
return running_config
131131

132132

133-
def reset_auc(auc_num=1):
133+
def reset_auc(use_fleet=False, auc_num=1):
134134
# for static clear auc
135135
auc_var_name = []
136136
for i in range(auc_num * 4):
@@ -143,5 +143,9 @@ def reset_auc(auc_num=1):
143143
tensor = param.get_tensor()
144144
if param:
145145
tensor_array = np.zeros(tensor._get_dims()).astype("int64")
146-
tensor.set(tensor_array, paddle.CPUPlace())
146+
if use_fleet:
147+
trainer_id = paddle.distributed.get_rank()
148+
tensor.set(tensor_array, paddle.CUDAPlace(trainer_id))
149+
else:
150+
tensor.set(tensor_array, paddle.CPUPlace())
147151
logger.info("AUC Reset To Zero: {}".format(name))

0 commit comments

Comments (0)