Skip to content

Commit 4697f88

Browse files
committed
upgrade zero inflated lognormal loss, support export structure path, upgrade pre-commit hooks
1 parent 96a93ae commit 4697f88

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

.github/workflows/code_style.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@ jobs:
         ref: ${{ github.event.pull_request.head.sha }}
         submodules: recursive
 
-    - name: Clean pre-commit
-      run: |
-        pre-commit --version
-        pre-commit clean || true
-        pre-commit gc || true
-
     - name: RunCiTest
       id: run_ci_test
      env:

.pre-commit-config.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@ repos:
       additional_dependencies: [
         'flake8-docstrings==1.5.0'
       ]
-  - repo: https://github.com/asottile/seed-isort-config
-    rev: v2.2.0
-    hooks:
-      - id: seed-isort-config
-  - repo: https://github.com/timothycrosley/isort
+  - repo: https://github.com/pycqa/isort
     rev: 4.3.21
     hooks:
       - id: isort

easy_rec/python/model/deepfm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,25 @@ def build_input_layer(self, model_config, feature_configs):
     if not has_final:
       assert model_config.deepfm.wide_output_dim == model_config.num_class
     self._wide_output_dim = model_config.deepfm.wide_output_dim
-    if self._wide_output_dim != 1:
+    if self._wide_output_dim != self._num_class:
       logging.warning(
           'wide_output_dim not equal to 1, it is not a standard model'
       )
     super(DeepFM, self).build_input_layer(model_config, feature_configs)
 
   def build_predict_graph(self):
     # Wide
-    wide_fea = tf.reduce_sum(
-        self._wide_features, axis=1, keepdims=True, name='wide_feature'
-    )
+    if self._num_class > 1 and self._wide_output_dim == self._num_class:
+      wide_shape = tf.shape(self._wide_features)
+      new_shape = tf.stack(
+          [-1, wide_shape[1] // self._num_class, self._num_class]
+      )
+      wide_fea = tf.reshape(self._wide_features, new_shape)
+      wide_fea = tf.reduce_sum(wide_fea, axis=1, name='wide_feature')
+    else:
+      wide_fea = tf.reduce_sum(
+          self._wide_features, axis=1, keepdims=True, name='wide_feature'
+      )
 
     # FM
     fm_fea = fm.FM(name='fm_feature')(self._fm_features)

0 commit comments

Comments
 (0)