# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math

import net


class DygraphModel():
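    """Dygraph wrapper around net.ESCMLayer.

    Builds the ESCM model, converts batch data into feed tensors, defines the
    multi-task loss with an IPW or DR counterfactual correction for the CVR
    head, and exposes the optimizer and AUC metrics used during training and
    inference.
    """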
    # define model
    def create_model(self, config):
        max_len = config.get("hyper_parameters.max_len", 3)
        sparse_feature_number = config.get(
            "hyper_parameters.sparse_feature_number")
        self.global_w = config.get("hyper_parameters.global_w", 0.5)
        self.counterfactual_w = config.get("hyper_parameters.counterfactual_w",
                                           0.5)
        sparse_feature_dim = config.get("hyper_parameters.sparse_feature_dim")
        num_field = config.get("hyper_parameters.num_field")
        learning_rate = config.get("hyper_parameters.optimizer.learning_rate")
        ctr_fc_sizes = config.get("hyper_parameters.ctr_fc_sizes")
        cvr_fc_sizes = config.get("hyper_parameters.cvr_fc_sizes")
        expert_num = config.get("hyper_parameters.expert_num")
        self.counterfact_mode = config.get("runner.counterfact_mode")
        expert_size = config.get("hyper_parameters.expert_size")
        tower_size = config.get("hyper_parameters.tower_size")
        feature_size = config.get("hyper_parameters.feature_size")

        escm_model = net.ESCMLayer(sparse_feature_number, sparse_feature_dim,
                                   num_field, ctr_fc_sizes, cvr_fc_sizes,
                                   expert_num, expert_size, tower_size,
                                   self.counterfact_mode, feature_size)

        return escm_model

    # define feeds which convert numpy data in a batch to paddle tensors
    def create_feeds(self, batch_data, config):
        max_len = config.get("hyper_parameters.max_len", 3)
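        # batch_data layout: every field except the last two is a sparse id
        # feature; batch_data[-2] is the click (CTR) label and batch_data[-1]
        # is the purchase/conversion (CTCVR) label.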
        sparse_tensor = []
        for b in batch_data[:-2]:
            sparse_tensor.append(
                paddle.to_tensor(b.numpy().astype('int64').reshape(-1,
                                                                   max_len)))
        ctr_label = paddle.to_tensor(batch_data[-2].numpy().astype('int64')
                                     .reshape(-1, 1))
        ctcvr_label = paddle.to_tensor(batch_data[-1].numpy().astype('int64')
                                       .reshape(-1, 1))
        return sparse_tensor, ctr_label, ctcvr_label

    # define loss function by predictions and labels
    def create_loss(self, ctr_out_one, ctr_clk, ctcvr_prop_one, ctcvr_buy,
                    cvr_out_one, out_list):
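        # Three pointwise log losses:
        #   ctr:   click prediction vs. click label
        #   cvr:   conversion prediction vs. purchase label (counterfactually
        #          corrected below, since conversions are only observed on clicks)
        #   ctcvr: click-and-convert prediction vs. purchase label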
        loss_ctr = paddle.nn.functional.log_loss(
            input=ctr_out_one, label=paddle.cast(
                ctr_clk, dtype="float32"))
        loss_cvr = paddle.nn.functional.log_loss(
            input=cvr_out_one, label=paddle.cast(
                ctcvr_buy, dtype="float32"))
        loss_ctcvr = paddle.nn.functional.log_loss(
            input=ctcvr_prop_one,
            label=paddle.cast(
                ctcvr_buy, dtype="float32"))
        ctr_num = paddle.sum(ctr_clk, axis=0)
        O = paddle.cast(ctr_clk, 'float32')
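        # O is the observation (click) indicator: the CVR loss can only be
        # observed on clicked impressions, so it is corrected with either a
        # doubly robust (DR) or inverse propensity weighting (IPW) estimator.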
| 81 | + if self.counterfact_mode == "DR": |
| 82 | + loss_cvr = self.counterfact_dr(loss_cvr, ctr_num, O, ctr_out_one, |
| 83 | + out_list[6]) |
| 84 | + else: |
| 85 | + loss_cvr = self.counterfact_ipw(loss_cvr, ctr_num, O, ctr_out_one) |
| 86 | + |
| 87 | + cost = loss_ctr + loss_cvr * self.counterfactual_w + loss_ctcvr * self.global_w |
| 88 | + avg_cost = paddle.mean(x=cost) |
| 89 | + return avg_cost |
| 90 | + |
| 91 | + # define optimizer |
| 92 | + def create_optimizer(self, dy_model, config): |
| 93 | + lr = config.get("hyper_parameters.optimizer.learning_rate", 0.001) |
| 94 | + optimizer = paddle.optimizer.Adam( |
| 95 | + learning_rate=lr, parameters=dy_model.parameters()) |
| 96 | + return optimizer |
| 97 | + |
| 98 | + # define metrics such as auc/acc |
| 99 | + # multi-task need to define multi metric |
| 100 | + def create_metrics(self): |
| 101 | + metrics_list_name = ["auc_ctr", "auc_cvr", "auc_ctcvr"] |
| 102 | + auc_ctr_metric = paddle.metric.Auc("ROC") |
| 103 | + auc_cvr_metric = paddle.metric.Auc("ROC") |
| 104 | + auc_ctcvr_metric = paddle.metric.Auc("ROC") |
| 105 | + metrics_list = [auc_ctr_metric, auc_cvr_metric, auc_ctcvr_metric] |
| 106 | + return metrics_list, metrics_list_name |
| 107 | + |
| 108 | + def counterfact_ipw(self, loss_cvr, ctr_num, O, ctr_out_one): |
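        # IPW estimator: form a propensity score from the predicted CTR and
        # the batch click count, clip its reciprocal (the inverse propensity
        # weight) for numerical stability, and use it to re-weight the
        # observed CVR loss on clicked impressions (O) only.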
        PS = paddle.multiply(
            ctr_out_one, paddle.cast(
                ctr_num, dtype="float32"))
        PS = paddle.multiply(PS, paddle.cast(ctr_num, dtype="float32"))
        min_v = paddle.full_like(PS, 0.000001)
        PS = paddle.maximum(PS, min_v)
        IPS = paddle.reciprocal(PS)
        #batch_shape = paddle.full_like(O, 1)
        #batch_size = paddle.sum(paddle.cast(batch_shape, dtype="float32"), axis=0)
        # TODO: this should be a hyperparameter
        IPS = paddle.clip(IPS, min=-15, max=15)  # online trick
        #IPS = paddle.multiply(IPS, batch_size)
        IPS.stop_gradient = True
        loss_cvr = paddle.multiply(loss_cvr, IPS)
        loss_cvr = paddle.multiply(loss_cvr, O)
        return loss_cvr

    def counterfact_dr(self, loss_cvr, ctr_num, O, ctr_out_one, imp_out):
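        # DR (doubly robust) estimator: the error part combines the imputed
        # CVR loss (imp_out) with an inverse-propensity-weighted residual
        # e = loss_cvr - imp_out on clicked impressions; the imputation part
        # trains imp_out to regress the observed CVR loss.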
        # DR error part
        loss_error_first = imp_out
        e = paddle.subtract(loss_cvr, imp_out)

        min_v = paddle.full_like(ctr_out_one, 0.000001)
        ctr_out_one = paddle.maximum(ctr_out_one, min_v)

        loss_error_second = paddle.multiply(O, e)
        loss_error_second = paddle.divide(loss_error_second, ctr_out_one)

        loss_error = loss_error_first + loss_error_second

        # DR imputation part
        loss_imp = paddle.square(e)
        loss_imp = paddle.multiply(loss_imp, O)
        loss_imp = paddle.divide(loss_imp, ctr_out_one)

        return loss_error + loss_imp

    # construct train forward phase
    def train_forward(self, dy_model, metrics_list, batch_data, config):
        sparse_tensor, label_ctr, label_ctcvr = self.create_feeds(batch_data,
                                                                  config)

        out_list = dy_model.forward(sparse_tensor)
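        # out_list: [ctr_out, ctr_out_one, cvr_out, cvr_out_one, ctcvr_prop,
        # ctcvr_prop_one]; in DR mode out_list[6] additionally holds the
        # imputation tower output consumed by counterfact_dr.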
        ctr_out, ctr_out_one, cvr_out, cvr_out_one, ctcvr_prop, ctcvr_prop_one = out_list[
            0], out_list[1], out_list[2], out_list[3], out_list[4], out_list[5]
        loss = self.create_loss(ctr_out_one, label_ctr, ctcvr_prop_one,
                                label_ctcvr, cvr_out_one, out_list)
        # update metrics
        metrics_list[0].update(preds=ctr_out.numpy(), labels=label_ctr.numpy())
        metrics_list[1].update(
            preds=cvr_out.numpy(), labels=label_ctcvr.numpy())
        metrics_list[2].update(
            preds=ctcvr_prop.numpy(), labels=label_ctcvr.numpy())

        # print_dict format: {'loss': loss}
        print_dict = {'loss': loss}
        return loss, metrics_list, print_dict

    def infer_forward(self, dy_model, metrics_list, batch_data, config):
        sparse_tensor, label_ctr, label_ctcvr = self.create_feeds(batch_data,
                                                                  config)

        ctr_out, ctr_out_one, cvr_out, cvr_out_one, ctcvr_prop, ctcvr_prop_one, D = dy_model.forward(
            sparse_tensor)
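        # D is an extra forward output that is not used for evaluation; only
        # the prediction heads feed the AUC metrics below.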
        # update metrics
        metrics_list[0].update(preds=ctr_out.numpy(), labels=label_ctr.numpy())
        metrics_list[1].update(
            preds=cvr_out.numpy(), labels=label_ctcvr.numpy())
        metrics_list[2].update(
            preds=ctcvr_prop.numpy(), labels=label_ctcvr.numpy())

        return metrics_list, None