Skip to content

Commit 02bc4f5

Browse files
committed
update escm2
1 parent db2a8cd commit 02bc4f5

File tree

4 files changed

+43
-35
lines changed

4 files changed

+43
-35
lines changed

models/multitask/escm2/config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ runner:
2222
train_batch_size: 2
2323
epochs: 3
2424
print_interval: 2
25-
#model_init_path: "output_model_esmm/2" # init model
26-
model_save_path: "output_model_esmm"
25+
#model_init_path: "output_model_escm/2" # init model
26+
model_save_path: "output_model_escm"
2727
test_data_dir: "data/train"
2828
infer_batch_size: 2
29-
infer_reader_path: "esmm_reader" # importlib format
30-
infer_load_path: "output_model_esmm"
29+
infer_reader_path: "escm_reader" # importlib format
30+
infer_load_path: "output_model_escm"
3131
infer_start_epoch: 0
3232
infer_end_epoch: 3
3333
counterfact_mode: "DR"

models/multitask/escm2/config_bigdata.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ runner:
2323
epochs: 10
2424
print_interval: 10
2525
#model_init_path: "output_model/0" # init model
26-
model_save_path: "output_model_esmm_all"
26+
model_save_path: "output_model_escm_all"
2727
test_data_dir: "../../../datasets/ali-ccp/test_data"
2828
infer_batch_size: 1024
29-
infer_reader_path: "esmm_reader" # importlib format
30-
infer_load_path: "output_model_esmm_all"
29+
infer_reader_path: "escm_reader" # importlib format
30+
infer_load_path: "output_model_escm_all"
3131
infer_start_epoch: 0
3232
infer_end_epoch: 10
3333
counterfact_mode: "DR"

models/multitask/escm2/dygraph_model.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def create_loss(self, ctr_out_one, ctr_clk, ctcvr_prop_one, ctcvr_buy,
7979
ctr_num = paddle.sum(ctr_clk, axis=0)
8080
O = paddle.cast(ctr_clk, 'float32')
8181
if self.counterfact_mode == "DR":
82-
loss_cvr = self.counterfact_dr(loss_cvr, ctr_num, O, ctr_out_one,
82+
loss_cvr = self.counterfact_dr(loss_cvr, O, ctr_out_one,
8383
out_list[6])
8484
else:
8585
loss_cvr = self.counterfact_ipw(loss_cvr, ctr_num, O, ctr_out_one)
@@ -109,39 +109,43 @@ def counterfact_ipw(self, loss_cvr, ctr_num, O, ctr_out_one):
109109
PS = paddle.multiply(
110110
ctr_out_one, paddle.cast(
111111
ctr_num, dtype="float32"))
112-
PS = paddle.multiply(PS, paddle.cast(ctr_num, dtype="float32"))
113112
min_v = paddle.full_like(PS, 0.000001)
114113
PS = paddle.maximum(PS, min_v)
115114
IPS = paddle.reciprocal(PS)
116-
#batch_shape = paddle.full_like(O, 1)
117-
#batch_size = paddle.sum(paddle.cast(batch_shape, dtype="float32"), axis=0)
115+
batch_shape = paddle.full_like(O, 1)
116+
batch_size = paddle.sum(paddle.cast(
117+
batch_shape, dtype="float32"),
118+
axis=0)
118119
#TODO this should be a hyperparameter
119120
IPS = paddle.clip(IPS, min=-15, max=15) #online trick
120-
#IPS = paddle.multiply(IPS, batch_size)
121+
IPS = paddle.multiply(IPS, batch_size)
121122
IPS.stop_gradient = True
122123
loss_cvr = paddle.multiply(loss_cvr, IPS)
123124
loss_cvr = paddle.multiply(loss_cvr, O)
124-
return loss_cvr
125+
return paddle.mean(loss_cvr)
125126

126-
def counterfact_dr(self, loss_cvr, ctr_num, O, ctr_out_one, imp_out):
127+
def counterfact_dr(self, loss_cvr, O, ctr_out_one, imp_out):
127128
#dr error part
128-
loss_error_first = imp_out
129129
e = paddle.subtract(loss_cvr, imp_out)
130130

131131
min_v = paddle.full_like(ctr_out_one, 0.000001)
132132
ctr_out_one = paddle.maximum(ctr_out_one, min_v)
133+
IPS = paddle.divide(paddle.cast(O, dtype="float32"), ctr_out_one)
133134

134-
loss_error_second = paddle.multiply(O, e)
135-
loss_error_second = paddle.divide(loss_error_second, ctr_out_one)
135+
IPS = paddle.clip(IPS, min=-15, max=15) #online trick
136+
IPS.stop_gradient = True
136137

137-
loss_error = loss_error_first + loss_error_second
138+
loss_error_second = paddle.multiply(e, IPS)
139+
140+
loss_error = imp_out + loss_error_second
138141

139142
#dr imp part
140143
loss_imp = paddle.square(e)
141-
loss_imp = paddle.multiply(loss_imp, O)
142-
loss_imp = paddle.divide(loss_imp, ctr_out_one)
144+
loss_imp = paddle.multiply(loss_imp, IPS)
145+
146+
loss_dr = loss_error + loss_imp
143147

144-
return loss_error + loss_imp
148+
return paddle.mean(loss_dr)
145149

146150
# construct train forward phase
147151
def train_forward(self, dy_model, metrics_list, batch_data, config):

models/multitask/escm2/static_model.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,39 +65,43 @@ def counterfact_ipw(self, loss_cvr, ctr_num, O, ctr_out_one):
6565
PS = paddle.multiply(
6666
ctr_out_one, paddle.cast(
6767
ctr_num, dtype="float32"))
68-
PS = paddle.multiply(PS, paddle.cast(ctr_num, dtype="float32"))
6968
min_v = paddle.full_like(PS, 0.000001)
7069
PS = paddle.maximum(PS, min_v)
7170
IPS = paddle.reciprocal(PS)
72-
#batch_shape = paddle.full_like(O, 1)
73-
#batch_size = paddle.sum(paddle.cast(batch_shape, dtype="float32"), axis=0)
71+
batch_shape = paddle.full_like(O, 1)
72+
batch_size = paddle.sum(paddle.cast(
73+
batch_shape, dtype="float32"),
74+
axis=0)
7475
#TODO this should be a hyperparameter
7576
IPS = paddle.clip(IPS, min=-15, max=15) #online trick
76-
#IPS = paddle.multiply(IPS, batch_size)
77+
IPS = paddle.multiply(IPS, batch_size)
7778
IPS.stop_gradient = True
7879
loss_cvr = paddle.multiply(loss_cvr, IPS)
7980
loss_cvr = paddle.multiply(loss_cvr, O)
80-
return loss_cvr
81+
return paddle.mean(loss_cvr)
8182

82-
def counterfact_dr(self, loss_cvr, ctr_num, O, ctr_out_one, imp_out):
83+
def counterfact_dr(self, loss_cvr, O, ctr_out_one, imp_out):
8384
#dr error part
84-
loss_error_first = imp_out
8585
e = paddle.subtract(loss_cvr, imp_out)
8686

8787
min_v = paddle.full_like(ctr_out_one, 0.000001)
8888
ctr_out_one = paddle.maximum(ctr_out_one, min_v)
89+
IPS = paddle.divide(paddle.cast(O, dtype="float32"), ctr_out_one)
8990

90-
loss_error_second = paddle.multiply(O, e)
91-
loss_error_second = paddle.divide(loss_error_second, ctr_out_one)
91+
IPS = paddle.clip(IPS, min=-15, max=15) #online trick
92+
IPS.stop_gradient = True
9293

93-
loss_error = loss_error_first + loss_error_second
94+
loss_error_second = paddle.multiply(e, IPS)
95+
96+
loss_error = imp_out + loss_error_second
9497

9598
#dr imp part
9699
loss_imp = paddle.square(e)
97-
loss_imp = paddle.multiply(loss_imp, O)
98-
loss_imp = paddle.divide(loss_imp, ctr_out_one)
100+
loss_imp = paddle.multiply(loss_imp, IPS)
101+
102+
loss_dr = loss_error + loss_imp
99103

100-
return loss_error + loss_imp
104+
return paddle.mean(loss_dr)
101105

102106
def net(self, inputs, is_infer=False):
103107

@@ -138,7 +142,7 @@ def net(self, inputs, is_infer=False):
138142
input=cvr_out_one, label=paddle.cast(
139143
ctcvr_buy, dtype="float32"))
140144
if self.counterfact_mode == "DR":
141-
loss_cvr = self.counterfact_dr(loss_cvr, ctr_num, O, ctr_out_one,
145+
loss_cvr = self.counterfact_dr(loss_cvr, O, ctr_out_one,
142146
out_list[6])
143147
else:
144148
loss_cvr = self.counterfact_ipw(loss_cvr, ctr_num, O, ctr_out_one)

0 commit comments

Comments
 (0)