Skip to content

Commit 313fbb6

Browse files
author
gongenlei
authored
[FT Decoder] Add JIT and bug fix (#1173)
* fix&feat: change activation from gelu to relu and add jit
* docs: update yaml
1 parent 433ef99 commit 313fbb6

File tree

6 files changed

+21
-12
lines changed

6 files changed

+21
-12
lines changed

paddlenlp/ops/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ if(WITH_ENCODER AND NOT ON_INFER)
5555
endif()
5656

5757
if(WITH_DECODER)
58-
list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
58+
list(APPEND decoding_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
5959
endif()
6060

6161
if(WITH_BART)

paddlenlp/ops/faster_transformer/sample/config/decoder.sample.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ dropout: 0.1
3636
weight_sharing: True
3737

3838
# Path of trained parameter, to make prediction
39-
init_from_params: base_trained_models/step_final
39+
init_from_params: base_trained_models/step_final

paddlenlp/ops/faster_transformer/sample/decoder_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def parse_args():
3333
help="Path of the config file. ")
3434
parser.add_argument(
3535
"--decoder_lib",
36-
default="../../build/lib/libdecoder_op.so",
36+
default="../../build/lib/libdecoding_op.so",
3737
type=str,
38-
help="Path of libdecoder_op.so. ")
38+
help="Path of libdecoding_op.so. ")
3939
parser.add_argument(
4040
"--use_fp16_decoder",
4141
action="store_true",

paddlenlp/ops/faster_transformer/src/CMakeLists.txt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ if(ON_INFER)
145145
set(DEPS ${DEPS} shlwapi.lib)
146146
endif(NOT WIN32)
147147

148-
cuda_add_library(decoding_infer_op ${decoding_op_files} ${decoder_op_files} SHARED)
148+
cuda_add_library(decoding_infer_op ${decoding_op_files} SHARED)
149149
add_dependencies(decoding_infer_op extern_${THIRD_PARTY_NAME} boost)
150150

151151
string(REPLACE "/" ";" DEMO_PATH ${DEMO})
@@ -271,8 +271,4 @@ else(ON_INFER)
271271
add_library(decoding_op SHARED ${decoding_op_files})
272272
add_dependencies(decoding_op extern_${THIRD_PARTY_NAME} boost)
273273
target_link_libraries(decoding_op PRIVATE -lcublas -lcudart ${lib_link} ${ft_lib_link} -lencoder)
274-
275-
add_library(decoder_op SHARED ${decoder_op_files})
276-
add_dependencies(decoder_op extern_${THIRD_PARTY_NAME} boost)
277-
target_link_libraries(decoder_op PRIVATE -lcublas -lcudart -ldecoder ${lib_link})
278274
endif()

paddlenlp/ops/faster_transformer/src/fusion_decoder_op.cu

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,13 @@ std::vector<paddle::Tensor> decoder_kernel(
7676
typedef typename traits_::data_t data_t_;
7777
typedef DecoderTransformerTraits<traits_::OpType> DecoderTraits_;
7878
OpenDecoder<DecoderTraits_::OpType>* decoder_;
79-
decoder_ = new OpenDecoder<DecoderTraits_::OpType>(
80-
batch_size_, max_seq_len_, n_head, size_per_head, memory_hidden_dim_);
79+
decoder_ = new OpenDecoder<DecoderTraits_::OpType>(batch_size_,
80+
max_seq_len_,
81+
n_head,
82+
size_per_head,
83+
memory_hidden_dim_,
84+
true,
85+
ActivationType::RELU);
8186

8287
DataType_* decoder_output = reinterpret_cast<DataType_*>(
8388
decoder_output_tensor.mutable_data<data_t_>());

paddlenlp/ops/faster_transformer/transformer/decoder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121
from paddle.fluid.layer_helper import LayerHelper
2222
from paddlenlp.transformers import WordEmbedding, PositionalEmbedding, position_encoding_init
23-
23+
from paddlenlp.utils.log import logger
24+
from paddlenlp.ops.ext_utils import load
2425
from paddlenlp.ops import transfer_param
2526

2627

@@ -115,6 +116,13 @@ def __init__(self,
115116
# Maybe it has been loaded by `ext_utils.load`
116117
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
117118
decoder_lib)
119+
else:
120+
if decoder_lib is not None:
121+
logger.warning(
122+
"The specified decoder_lib does not exist, and it will be built automatically."
123+
)
124+
load("FasterTransformer", verbose=True)
125+
118126
super(InferTransformerDecoder, self).__init__()
119127
self.n_head = n_head
120128
self.size_per_head = size_per_head

0 commit comments

Comments
 (0)