Skip to content

Commit b3ae973

Browse files
committed
fix
1 parent 5978804 commit b3ae973

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

models/multitask/mmoe/census_reader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ def __iter__(self):
4545
label_marital = [1]
4646
output_list = []
4747
output_list.append(np.array(data).astype('float32'))
48-
output_list.append(
49-
np.array(label_income).astype('float32'))
50-
output_list.append(
51-
np.array(label_marital).astype('float32'))
48+
output_list.append(np.array(label_income).astype('int64'))
49+
output_list.append(np.array(label_marital).astype('int64'))
5250
yield output_list

models/rank/logistic_regression/criteo_lr_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,6 @@ def __iter__(self):
7070
else:
7171
self.visit[slot] = False
7272
# label, feat_idx, feat_value
73-
yield np.array(output[0][1]).astype('int32'), np.array(
73+
yield np.array(output[0][1]).astype('int64'), np.array(
7474
output[1][1]).astype('int64'), np.array(output[2][
7575
1]).astype('float32')

models/rank/logistic_regression/dygraph_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def create_model(self, config):
3535
# define feeds which convert numpy of batch data to paddle.tensor
3636
def create_feeds(self, batch_data, config):
3737
num_field = config.get('hyper_parameters.num_field', None)
38-
label = paddle.to_tensor(batch_data[0].numpy().astype('int32').reshape(
38+
label = paddle.to_tensor(batch_data[0].numpy().astype('int64').reshape(
3939
-1, 1))
4040
feat_idx = paddle.to_tensor(batch_data[1].numpy().astype('int64')
4141
.reshape(-1, 1))

0 commit comments

Comments
 (0)