Skip to content

Commit d774f05

Browse files
committed
modify dygraph_model.py
1 parent cb7dbd9 commit d774f05

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

models/multitask/metaheac/dygraph_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ def create_model(self, config):
3333

3434
meta_model = net.WideAndDeepModel(max_idxs, embed_dim, mlp_dims,
3535
num_expert, num_output)
36-
# model_state_dict = paddle.load('paddle.pkl')
37-
# meta_model.set_dict(model_state_dict)
3836

3937
return meta_model
4038

@@ -118,6 +116,10 @@ def train_forward(self, dy_model, metric_list, batch, config):
118116
loss_q = criterion(query_set_y_pred, label)
119117
losses_q.append(loss_q) # Save the loss on the subtask dataset
120118

119+
pred = paddle.unsqueeze(query_set_y_pred, 1)
120+
pred = paddle.concat([1 - pred, pred], 1)
121+
metric_list[0].update(preds=pred.numpy(), labels=label.numpy())
122+
121123
loss_average = paddle.stack(losses_q).mean(0)
122124
print_dict = {'loss': loss_average}
123125

0 commit comments

Comments
 (0)