Skip to content

Commit b4697f0

Browse files
authored
Fix faster fp16 (#4423)
1 parent 105f572 commit b4697f0

File tree

4 files changed: +17 −5 lines changed

model_zoo/gpt/fast_gpt/export_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def do_predict(args):
     model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     logger.info("Loading the model parameters, please wait...")
-    model = model_class.from_pretrained(args.model_name_or_path, max_predict_len=args.max_out_len)
+    model = model_class.from_pretrained(args.model_name_or_path)
 
     gpt = FasterGPT(model=model, decoding_lib=args.decoding_lib, use_fp16_decoding=args.use_fp16_decoding)

paddlenlp/ops/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ project(FasterTransformer LANGUAGES C CXX CUDA)
 
 find_package(CUDA 10.1 REQUIRED)
 
+find_program(CCACHE_PROGRAM ccache)
+if(CCACHE_PROGRAM)
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
+endif()
+
 INCLUDE(ExternalProject)
 
 set(CXX_STD "14" CACHE STRING "C++ standard")

paddlenlp/ops/fast_transformer/transformer/fast_transformer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def save_resources(self, tokenizer, path):
 
 class FasterGPT(GPTPretrainedModel):
     def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
-        super(FasterGPT, self).__init__()
+        super(FasterGPT, self).__init__(model.config)
         self._model = model
         self.use_fp16_decoding = use_fp16_decoding
         self.decoding = InferGptDecoding(model=model, decoding_lib=decoding_lib, use_fp16_decoding=use_fp16_decoding)
@@ -1923,9 +1923,9 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         self.use_fp16_decoding = use_fp16_decoding
         self._model = model
         if use_fp16_decoding:
-            weight_attr = paddle.ParamAttr(initializer=nn.initializer.Assign(model.mbart.encoder.embed_tokens.weight))
-            model.mbart.encoder.embed_tokens = nn.Embedding(
-                *model.mbart.encoder.embed_tokens.weight.shape, weight_attr=weight_attr
+            weight_attr = paddle.ParamAttr(initializer=nn.initializer.Assign(model.encoder.embed_tokens.weight))
+            model.encoder.embed_tokens = nn.Embedding(
+                *model.encoder.embed_tokens.weight.shape, weight_attr=weight_attr
             )
         self.encoder = model.t5.get_encoder()
         self.decoder = model.t5.get_decoder()

paddlenlp/ops/patches/FasterTransformer/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ project(FasterTransformer LANGUAGES CXX CUDA)
 
 find_package(CUDA 10.1 REQUIRED)
 
+find_program(CCACHE_PROGRAM ccache)
+if(CCACHE_PROGRAM)
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
+endif()
+
 option(BUILD_PD "Build in PaddlePaddle mode" ON)
 option(BUILD_GPT "Build project with gpt" ON)
 option(BUILD_ENCODER "Build project with encoder" ON)

0 commit comments

Comments (0)