Commit f631d3e

Delete confusing parameter to make FT JIT easy to activate (#1495)
* fix need build of faster generation
* update
* update
1 parent 8aa407e commit f631d3e

File tree (6 files changed, +55 −36 lines)

- paddlenlp/ops/ext_utils.py
- paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
- paddlenlp/ops/faster_transformer/transformer/decoder.py
- paddlenlp/ops/faster_transformer/transformer/decoding.py
- paddlenlp/ops/faster_transformer/transformer/encoder.py
- paddlenlp/ops/faster_transformer/transformer/faster_transformer.py


paddlenlp/ops/ext_utils.py

Lines changed: 10 additions & 2 deletions
@@ -34,6 +34,8 @@
     # Clear it for other non-CUDA situations.
     CUDA_HOME = None
 
+LOADED_EXT = {}
+
 
 def _get_files(path):
     """
@@ -221,6 +223,8 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
         logger.warning("%s is not available because CUDA can not be found." %
                        name)
         raise NotImplementedError
+    if name in LOADED_EXT.keys():
+        return LOADED_EXT[name]
     if build_dir is None:
         # Maybe under package dir is better to avoid cmake source path conflict
         # with different source path.
@@ -247,7 +251,9 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
                 ext_sources, ext_filepath, 'newer'):
             logger.debug("skipping '%s' extension (up-to-date) build" %
                          name)
-            return load_op_meta_info_and_register_op(ext_filepath)
+            ops = load_op_meta_info_and_register_op(ext_filepath)
+            LOADED_EXT[name] = ops
+            return LOADED_EXT[name]
 
     # write setup file and jit compile
     file_path = os.path.join(build_dir, "{}_setup.py".format(name))
@@ -256,7 +262,9 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
     if isinstance(extension, CMakeExtension):
         # Load a shared library (if exists) only to register op.
        if os.path.exists(ext_filepath):
-            ops = load_op_meta_info_and_register_op(ext_filepath)
-            load_op_meta_info_and_register_op(ext_filepath)
+            ops = load_op_meta_info_and_register_op(ext_filepath)
+            LOADED_EXT[name] = ops
+            return LOADED_EXT[name]
     else:
         # Import as callable python api
         return _import_module_from_library(name, build_base_dir, verbose)
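
The cache changes what repeated loads return: the first successful call registers the custom ops and stores the handle in LOADED_EXT, and later calls for the same name short-circuit on the cache instead of registering the shared library again. A minimal sketch of the resulting call pattern, assuming a CUDA-enabled Paddle build (otherwise `load` raises NotImplementedError as shown above):

    from paddlenlp.ops.ext_utils import LOADED_EXT, load

    # First call: JIT-builds the extension if needed, registers the custom ops,
    # and caches the returned op handles under the extension name.
    ops = load("FasterTransformer", verbose=True)

    # Subsequent calls hit the `if name in LOADED_EXT.keys()` shortcut and
    # return the cached handles instead of re-registering the shared library.
    assert load("FasterTransformer") is LOADED_EXT["FasterTransformer"]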

paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def do_predict(args):
 
     if args.enable_faster_encoder:
         transformer = enable_faster_encoder(
-            transformer, need_build=False, use_fp16=args.use_fp16_encoder)
+            transformer, use_fp16=args.use_fp16_encoder)
 
     src_word = generate_src_word(
         batch_size=args.infer_batch_size,

paddlenlp/ops/faster_transformer/transformer/decoder.py

Lines changed: 5 additions & 3 deletions
@@ -21,7 +21,7 @@
 from paddle.fluid.layer_helper import LayerHelper
 from paddlenlp.transformers import WordEmbedding, PositionalEmbedding, position_encoding_init
 from paddlenlp.utils.log import logger
-from paddlenlp.ops.ext_utils import load
+from paddlenlp.ops.ext_utils import load, LOADED_EXT
 from paddlenlp.ops import transfer_param
 
 
@@ -160,8 +160,10 @@ def __init__(self,
 
         if decoder_lib is not None and os.path.isfile(decoder_lib):
             # Maybe it has been loadad by `ext_utils.load`
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoder_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoder_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoder_lib is not None:
                 logger.warning(
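
The same "register only once" guard recurs throughout decoding.py below. A minimal sketch of the pattern in isolation (the helper name and library path are hypothetical, not part of this commit):

    import paddle
    from paddlenlp.ops.ext_utils import LOADED_EXT

    def register_ft_lib_once(lib_path, key="FasterTransformer"):
        # Register a prebuilt FasterTransformer shared library at most once,
        # mirroring the guard added in decoder.py and decoding.py.
        if key not in LOADED_EXT.keys():
            ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(lib_path)
            LOADED_EXT[key] = ops
        return LOADED_EXT[key]

    # e.g. register_ft_lib_once("./build/lib/libdecoding_op.so")  # hypothetical path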

paddlenlp/ops/faster_transformer/transformer/decoding.py

Lines changed: 21 additions & 11 deletions
@@ -23,7 +23,7 @@
 from paddle.fluid.layer_helper import LayerHelper
 import paddle
 
-from paddlenlp.ops.ext_utils import load
+from paddlenlp.ops.ext_utils import load, LOADED_EXT
 from paddlenlp.utils.log import logger
 
 
@@ -610,8 +610,10 @@ def __init__(self,
         # raise ValueError("The path to decoding lib is not exist.")
         if decoding_lib is not None and os.path.isfile(decoding_lib):
             # Maybe it has been loadad by `ext_utils.load`
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoding_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoding_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoding_lib is not None:
                 logger.warning(
@@ -870,8 +872,10 @@ def parse_function(func_name):
 class InferGptDecoding(nn.Layer):
     def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         if decoding_lib is not None and os.path.isfile(decoding_lib):
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoding_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoding_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoding_lib is not None:
                 logger.warning(
@@ -1078,8 +1082,10 @@ def __init__(self,
                  hidden_act="gelu"):
         if decoding_lib is not None and os.path.isfile(decoding_lib):
             # Maybe it has been loadad by `ext_utils.load`
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoding_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoding_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoding_lib is not None:
                 logger.warning(
@@ -1442,8 +1448,10 @@ class InferBartDecoding(nn.Layer):
     def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         if decoding_lib is not None and os.path.isfile(decoding_lib):
             # Maybe it has been loadad by `ext_utils.load`
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoding_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoding_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoding_lib is not None:
                 logger.warning(
@@ -1683,8 +1691,10 @@ def __init__(self,
                  hidden_act="gelu"):
         if decoding_lib is not None and os.path.isfile(decoding_lib):
             # Maybe it has been loadad by `ext_utils.load`
-            paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
-                decoding_lib)
+            if "FasterTransformer" not in LOADED_EXT.keys():
+                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                    decoding_lib)
+                LOADED_EXT["FasterTransformer"] = ops
         else:
             if decoding_lib is not None:
                 logger.warning(

paddlenlp/ops/faster_transformer/transformer/encoder.py

Lines changed: 16 additions & 17 deletions
@@ -238,10 +238,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
     return output
 
 
-def enable_faster_encoder(self,
-                          need_build=True,
-                          use_fp16=False,
-                          encoder_lib=None):
+def enable_faster_encoder(self, use_fp16=False, encoder_lib=None):
     """
     Compiles fusion encoder operator intergrated FasterTransformer using the
     method of JIT(Just-In-Time) and replaces the `forward` function of
@@ -281,19 +278,21 @@ def init_func(layer):
             convert_to_fp16(layer)
 
     if not self.training:
-        if need_build:
-            try:
-                # Pass decoding lib to prevent re-building encoder.
-                # Todo: check weather decoding lib have contained encoder or not.
-                if encoder_lib is not None:
-                    load_op_meta_info_and_register_op(encoder_lib)
-                else:
-                    load("FasterTransformer", verbose=True)
-            except Exception:
-                logger.warning(
-                    "Exception occurs when using FasterEncoder. " \
-                    "The original forward will be involved. ")
-                return self
+        try:
+            # Pass decoding lib to prevent re-building encoder.
+            # Todo: check weather decoding lib have contained encoder or not.
+            if encoder_lib is not None:
+                if "FasterTransformer" not in LOADED_EXT.keys():
+                    ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+                        decoding_lib)
+                    LOADED_EXT["FasterTransformer"] = ops
+            else:
+                load("FasterTransformer", verbose=True)
+        except Exception:
+            logger.warning(
+                "Exception occurs when using FasterEncoder. " \
+                "The original forward will be involved. ")
+            return self
         for layer in self.children():
             layer.apply(init_func)
     return self
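
With `need_build` removed, callers only decide whether to pass `use_fp16` or a prebuilt `encoder_lib`; the JIT build (or library registration) now always happens inside `enable_faster_encoder`, falling back to the original forward if it fails. A minimal usage sketch, with a hypothetical toy model standing in for the sample script's TransformerModel and assuming the usual top-level re-export from `paddlenlp.ops`:

    import paddle.nn as nn

    from paddlenlp.ops import enable_faster_encoder


    class ToyModel(nn.Layer):
        # Hypothetical stand-in for the TransformerModel used in
        # encoder_decoding_sample.py; any Layer containing nn.TransformerEncoder works.
        def __init__(self):
            super().__init__()
            layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
            self.encoder = nn.TransformerEncoder(layer, num_layers=6)

        def forward(self, x, mask=None):
            return self.encoder(x, mask)


    model = ToyModel()
    model.eval()  # the fused forward is only swapped in outside of training

    # No need_build flag any more: the FasterTransformer ops are JIT-built via
    # ext_utils.load unless encoder_lib points at a prebuilt library.
    model = enable_faster_encoder(model, use_fp16=False)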

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -1161,7 +1161,7 @@ def forward(self,
                 **model_kwargs):
 
         if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
+            self.encoder = enable_faster_encoder(self.encoder)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
@@ -1265,7 +1265,7 @@ def forward(self,
 
         #(gongenlei) Not enable_faster_encoder temporarily
         if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
+            self.encoder = enable_faster_encoder(self.encoder)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
