Skip to content

Commit 96a93ae

Browse files
committed
upgrade zero inflated lognormal loss, support export structure path, upgrade pre-commit hooks
1 parent 76b9f10 commit 96a93ae

File tree

179 files changed

+372
-146
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

179 files changed

+372
-146
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ repos:
66
additional_dependencies: [
77
'flake8-docstrings==1.5.0'
88
]
9-
- repo: https://github.com/pycqa/isort
10-
rev: 5.12.0
9+
- repo: https://github.com/asottile/seed-isort-config
10+
rev: v2.2.0
11+
hooks:
12+
- id: seed-isort-config
13+
- repo: https://github.com/timothycrosley/isort
14+
rev: 4.3.21
1115
hooks:
1216
- id: isort
1317
- repo: https://github.com/pre-commit/mirrors-yapf

docs/source/_ext/post_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
23
from docutils import nodes
34
from docutils.transforms import Transform
45

docs/source/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# documentation root, use os.path.abspath to make it absolute, like shown here.
1414
#
1515
import os
16-
import sphinx_rtd_theme
1716
import sys
1817

18+
import sphinx_rtd_theme
19+
1920
import easy_rec
2021

2122
# sys.path.insert(0, os.path.abspath('.'))

easy_rec/python/builders/loss_builder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# -*- encoding:utf-8 -*-
22
# Copyright (c) Alibaba, Inc. and its affiliates.
33
import logging
4+
45
import numpy as np
56
import tensorflow as tf
67

7-
from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
88
from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
99
from easy_rec.python.loss.jrc_loss import jrc_loss
10+
from easy_rec.python.protos.loss_pb2 import LossType
11+
12+
from easy_rec.python.loss.f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
1013
from easy_rec.python.loss.listwise_loss import listwise_distill_loss, listwise_rank_loss # NOQA
1114
from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss, pairwise_hinge_loss, pairwise_logistic_loss, pairwise_loss # NOQA
1215
from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
13-
from easy_rec.python.protos.loss_pb2 import LossType
1416

1517
if tf.__version__ >= '2.0':
1618
tf = tf.compat.v1

easy_rec/python/builders/optimizer_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ==============================================================================
1616
"""Functions to build training optimizers."""
1717
import logging
18+
1819
import tensorflow as tf
1920

2021
from easy_rec.python.compat import weight_decay_optimizers

easy_rec/python/compat/adam_s.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
from tensorflow.python.eager import context
1919
from tensorflow.python.framework import ops
20-
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops, resource_variable_ops, state_ops # NOQA
2120
from tensorflow.python.training import optimizer, training_ops
2221

22+
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops, resource_variable_ops, state_ops # NOQA
23+
2324

2425
class AdamOptimizerS(optimizer.Optimizer):
2526
"""Optimizer that implements the Adam algorithm.

easy_rec/python/compat/dynamic_variable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
#
1616

1717
import json
18+
1819
import tensorflow as tf
1920
from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops
2021
from sparse_operation_kit.experiment.communication import num_gpus
2122
from tensorflow.python.eager import context
2223
from tensorflow.python.framework import ops
2324
# from tensorflow.python.ops import array_ops
2425
from tensorflow.python.ops import resource_variable_ops
26+
2527
from tensorflow.python.ops.resource_variable_ops import ResourceVariable, variable_accessed # NOQA
2628

2729
# from tensorflow.python.util import object_identity

easy_rec/python/compat/early_stopping.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,21 @@
1919
import logging
2020
import operator
2121
import os
22-
import tensorflow as tf
2322
import threading
2423
import time
24+
25+
import tensorflow as tf
2526
from distutils.version import LooseVersion
2627
from tensorflow.python.framework import dtypes, ops
2728
from tensorflow.python.ops import init_ops, state_ops, variable_scope
2829
from tensorflow.python.platform import gfile, tf_logging
2930
from tensorflow.python.summary import summary_iterator
30-
from tensorflow.python.training import basic_session_run_hooks, session_run_hook, training_util # NOQA
3131

3232
from easy_rec.python.utils.config_util import parse_time
3333
from easy_rec.python.utils.load_class import load_by_path
3434

35+
from tensorflow.python.training import basic_session_run_hooks, session_run_hook, training_util # NOQA
36+
3537
if LooseVersion(tf.__version__) >= LooseVersion('2.12.0'):
3638
from tensorflow_estimator.python.estimator.estimator_export import estimator_export # NOQA
3739
else:

easy_rec/python/compat/embedding_parallel_saver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
# -*- encoding:utf-8 -*-
22

33
import logging
4-
import numpy as np
54
import os
5+
6+
import numpy as np
67
from tensorflow.core.protobuf import saver_pb2
78
from tensorflow.python.framework import dtypes, ops
8-
# from tensorflow.python.ops import math_ops
9-
# from tensorflow.python.ops import logging_ops
10-
from tensorflow.python.ops import array_ops, control_flow_ops, script_ops, state_ops # NOQA
119
from tensorflow.python.platform import gfile
1210
from tensorflow.python.training import saver
1311

1412
from easy_rec.python.utils import constant
1513

14+
# from tensorflow.python.ops import math_ops
15+
# from tensorflow.python.ops import logging_ops
16+
from tensorflow.python.ops import array_ops, control_flow_ops, script_ops, state_ops # NOQA
17+
1618
try:
1719
import horovod.tensorflow as hvd
1820
from sparse_operation_kit.experiment import raw_ops as dynamic_variable_ops

easy_rec/python/compat/estimator_train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
# Copyright (c) Alibaba, Inc. and its affiliates.
33
import logging
44
import os
5+
56
import tensorflow as tf
6-
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training # NOQA
77
from tensorflow.python.estimator import run_config as run_config_lib
8-
from tensorflow.python.estimator.training import _assert_eval_spec, _ContinuousEvalListener, _TrainingExecutor # NOQA
98
from tensorflow.python.util import compat
109

1110
from easy_rec.python.compat.exporter import FinalExporter
1211
from easy_rec.python.utils import estimator_utils
1312

13+
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training # NOQA
14+
from tensorflow.python.estimator.training import _assert_eval_spec, _ContinuousEvalListener, _TrainingExecutor # NOQA
15+
1416
if tf.__version__ >= '2.0':
1517
tf = tf.compat.v1
1618
gfile = tf.gfile

0 commit comments

Comments
 (0)