@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
119
119
helper = LayerHelper ("auc" , ** locals ())
120
120
auc_out = helper .create_tmp_variable (dtype = "float64" )
121
121
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
122
- tp = helper .create_global_variable (persistable = True , dtype = 'int64' )
123
- tn = helper .create_global_variable (persistable = True , dtype = 'int64' )
124
- fp = helper .create_global_variable (persistable = True , dtype = 'int64' )
125
- fn = helper .create_global_variable (persistable = True , dtype = 'int64' )
122
+ tp = helper .create_global_variable (
123
+ persistable = True , dtype = 'int64' , shape = [num_thresholds ])
124
+ tn = helper .create_global_variable (
125
+ persistable = True , dtype = 'int64' , shape = [num_thresholds ])
126
+ fp = helper .create_global_variable (
127
+ persistable = True , dtype = 'int64' , shape = [num_thresholds ])
128
+ fn = helper .create_global_variable (
129
+ persistable = True , dtype = 'int64' , shape = [num_thresholds ])
126
130
for var in [tp , tn , fp , fn ]:
127
131
helper .set_variable_initializer (
128
132
var , Constant (
0 commit comments