Skip to content

Commit 745a201

Browse files
committed
Fix bug of CategoricalFocalLoss #20
1 parent 6d54926 commit 745a201

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

deeptables/models/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ def __init__(self, gamma=2., alpha=.25, reduction=losses.Reduction.AUTO, name='f
986986
gamma {float} -- (default: {2.0})
987987
alpha {float} -- (default: {4.0})
988988
"""
989-
super(BinaryFocalLoss, self).__init__(reduction=reduction, name=name)
989+
super(CategoricalFocalLoss, self).__init__(reduction=reduction, name=name)
990990
self.gamma = float(gamma)
991991
self.alpha = float(alpha)
992992

deeptables/preprocessing/transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from sklearn.utils import column_or_1d
88
from ..utils import dt_logging, consts
99

10+
from sklearn.pipeline import Pipeline
11+
from sklearn.base import BaseEstimator,TransformerMixin
12+
from sklearn.pipeline import FeatureUnion
13+
1014
logger = dt_logging.get_logger()
1115

1216

0 commit comments

Comments
 (0)