Skip to content

Commit ef20863

Browse files
authored
fix ernie-encoder (#4281)
1 parent 5d542a2 commit ef20863

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

paddlenlp/transformers/semantic_search/modeling.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,24 +16,25 @@
1616
import paddle.nn as nn
1717
import paddle.nn.functional as F
1818

19-
from ..ernie.modeling import ErniePretrainedModel
19+
from ..ernie.configuration import ErnieConfig
20+
from ..ernie.modeling import ErnieModel, ErniePretrainedModel
2021

2122
__all__ = ["ErnieDualEncoder", "ErnieCrossEncoder"]
2223

2324

2425
class ErnieEncoder(ErniePretrainedModel):
25-
def __init__(self, ernie, dropout=None, output_emb_size=None, num_classes=2):
26-
super(ErnieEncoder, self).__init__()
27-
self.ernie = ernie # allow ernie to be config
28-
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
29-
self.classifier = nn.Linear(self.ernie.config["hidden_size"], num_classes)
26+
def __init__(self, config: ErnieConfig, output_emb_size: int):
27+
super(ErnieEncoder, self).__init__(config)
28+
29+
self.ernie = ErnieModel(config)
30+
dropout = config.classifier_dropout if config.classifier_dropout is not None else 0.1
31+
32+
self.dropout = nn.Dropout(dropout)
33+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
3034
# Compatible to ERNIE-Search for adding extra linear layer
31-
self.output_emb_size = output_emb_size
3235
if output_emb_size is not None and output_emb_size > 0:
3336
weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
34-
self.emb_reduce_linear = paddle.nn.Linear(
35-
self.ernie.config["hidden_size"], output_emb_size, weight_attr=weight_attr
36-
)
37+
self.emb_reduce_linear = paddle.nn.Linear(config.hidden_size, output_emb_size, weight_attr=weight_attr)
3738
self.apply(self.init_weights)
3839

3940
def init_weights(self, layer):

0 commit comments

Comments (0)