
Commit 0b6284e

[Inference] Update DygraphInferencePredictor (#9491)
* update DygraphInferencePredictor
* update batch_size
1 parent d455181 commit 0b6284e

File tree

3 files changed: +22 -10 lines changed


csrc/gpu/quant_int8.cu

Lines changed: 4 additions & 0 deletions
@@ -65,7 +65,11 @@ __forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hi
 #else
 template<>
 __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+#if __CUDA_ARCH__ >= 800
   return __hmul(__hadd(a, b), c);
+#else
+  return (static_cast<float>(a) + static_cast<float>(b)) * static_cast<float>(c);
+#endif
 }
 #endif

csrc/setup_cuda.py

Lines changed: 10 additions & 7 deletions
@@ -57,8 +57,7 @@ def strtobool(v):
 
 def get_gencode_flags():
     if not strtobool(os.getenv("FLAG_LLM_PDC", "False")):
-        prop = paddle.device.cuda.get_device_properties()
-        cc = prop.major * 10 + prop.minor
+        cc = get_sm_version()
         return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
     else:
         # support more cuda archs
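Note: the two deleted lines show the logic that get_sm_version() presumably centralizes. A minimal sketch of such a helper, reconstructed from the removed code (the real definition lives elsewhere in csrc), would be:

import paddle


def get_sm_version():
    # Fold the device's compute capability into the usual two-digit SM number,
    # e.g. 8.0 -> 80, 8.9 -> 89 (sketch reconstructed from the removed lines).
    prop = paddle.device.cuda.get_device_properties()
    return prop.major * 10 + prop.minor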
@@ -75,6 +74,7 @@ def get_gencode_flags():
 gencode_flags = get_gencode_flags()
 library_path = os.environ.get("LD_LIBRARY_PATH", "/usr/local/cuda/lib64")
 
+sm_version = get_sm_version()
 
 sources = [
     "./gpu/save_with_output.cc",
@@ -102,16 +102,11 @@ def get_gencode_flags():
     "./gpu/dequant_int8.cu",
     "./gpu/flash_attn_bwd.cc",
     "./gpu/tune_cublaslt_gemm.cu",
-    "./gpu/append_attention.cu",
-    "./gpu/append_attn/get_block_shape_and_split_kv_block.cu",
-    "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu",
-    "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu",
     "./gpu/sample_kernels/top_p_sampling_reject.cu",
     "./gpu/update_inputs_v2.cu",
     "./gpu/set_preids_token_penalty_multi_scores.cu",
     "./gpu/speculate_decoding_kernels/ngram_match.cc",
 ]
-sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
 sources += find_end_files("./gpu/speculate_decoding_kernels", ".cu")
 
 nvcc_compile_args = gencode_flags
@@ -138,6 +133,14 @@ def get_gencode_flags():
 if cc >= 80:
     sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
 
+    sources += [
+        "./gpu/append_attention.cu",
+        "./gpu/append_attn/get_block_shape_and_split_kv_block.cu",
+        "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu",
+        "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu",
+    ]
+    sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
+
 if cc >= 89 and cuda_version >= 12.4:
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
     os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

llm/predict/predictor.py

Lines changed: 8 additions & 3 deletions
@@ -733,10 +733,9 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
             inputs[key] = paddle.to_tensor(inputs[key])
 
         inputs["cache_kvs"] = self.cache_kvs
-        self.model.generate(
+        return self.model.generate(
             **inputs,
         )
-        return None
 
 
 class BlockInferencePredictorMixin(BasePredictor):
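Note: with this change _infer forwards the return value of model.generate instead of always returning None, so a dygraph caller can consume the generated output directly. A hypothetical usage sketch (names are placeholders, not part of the commit):

def run_dygraph_infer(predictor, model_inputs):
    # `predictor` is assumed to be a DygraphInferencePredictor instance and
    # `model_inputs` the dict its preprocessing builds.
    result = predictor._infer(model_inputs)
    # Before this commit `result` was always None; now it carries whatever
    # self.model.generate(**inputs) returned.
    return result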
@@ -914,6 +913,12 @@ def init_model_inputs(self, config: PredictorArgument):
             self.model_inputs["rope_emb"] = paddle.concat([src_mask.reshape([-1]), tgt_mask.reshape([-1])])
 
     def _preprocess(self, input_text: list[str]):
+        len_input_text = len(input_text)
+        if len_input_text < self.batch_size:
+            padding_len = self.batch_size - len_input_text
+            input_text += [""] * padding_len
+        assert len(input_text) == self.batch_size
+
         if self.tokenizer.chat_template is not None:
             input_text = [input_text] if isinstance(input_text, str) else input_text
             input_text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in input_text]
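Note: the added block pads short request lists with empty prompts so the fixed batch size the predictor was configured with is always filled. The same idea as a standalone sketch (hypothetical helper name, not part of the commit):

def pad_to_batch_size(input_text, batch_size):
    # Pad the prompt list with empty strings up to the fixed batch size.
    if len(input_text) < batch_size:
        input_text = input_text + [""] * (batch_size - len(input_text))
    assert len(input_text) == batch_size
    return input_text


# Example with batch_size=4:
# pad_to_batch_size(["hello", "world"], 4) -> ["hello", "world", "", ""]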
@@ -1073,7 +1078,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         if self.tensor_parallel_rank == 0:
             outputs = []
             output_tokens = []
-            while len(outputs) < self.batch_size:
+            while len(outputs) < len(input_texts):
                 result = result_queue.get(timeout=1)
                 outputs.append(result[-1])
                 output_tokens.append(result[-2])
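Note: this works together with the padding in _preprocess: the loop now waits for one result per real prompt (len(input_texts)) rather than a full batch (self.batch_size), so only as many results as real prompts are collected. A toy sketch of that consumption pattern (names are illustrative; the field order mirrors the loop above):

from queue import Queue


def collect_results(result_queue: Queue, num_prompts: int):
    # Drain exactly one result per real prompt from the worker queue.
    outputs, output_tokens = [], []
    while len(outputs) < num_prompts:
        result = result_queue.get(timeout=1)
        outputs.append(result[-1])        # same field the loop above stores in `outputs`
        output_tokens.append(result[-2])  # same field it stores in `output_tokens`
    return outputs, output_tokens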
