@@ -57,8 +57,7 @@ def strtobool(v):
5757
5858def get_gencode_flags ():
5959 if not strtobool (os .getenv ("FLAG_LLM_PDC" , "False" )):
60- prop = paddle .device .cuda .get_device_properties ()
61- cc = prop .major * 10 + prop .minor
60+ cc = get_sm_version ()
6261 return ["-gencode" , "arch=compute_{0},code=sm_{0}" .format (cc )]
6362 else :
6463 # support more cuda archs
@@ -75,6 +74,7 @@ def get_gencode_flags():
7574gencode_flags = get_gencode_flags ()
7675library_path = os .environ .get ("LD_LIBRARY_PATH" , "/usr/local/cuda/lib64" )
7776
77+ sm_version = get_sm_version ()
7878
7979sources = [
8080 "./gpu/save_with_output.cc" ,
@@ -102,16 +102,11 @@ def get_gencode_flags():
102102 "./gpu/dequant_int8.cu" ,
103103 "./gpu/flash_attn_bwd.cc" ,
104104 "./gpu/tune_cublaslt_gemm.cu" ,
105- "./gpu/append_attention.cu" ,
106- "./gpu/append_attn/get_block_shape_and_split_kv_block.cu" ,
107- "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu" ,
108- "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu" ,
109105 "./gpu/sample_kernels/top_p_sampling_reject.cu" ,
110106 "./gpu/update_inputs_v2.cu" ,
111107 "./gpu/set_preids_token_penalty_multi_scores.cu" ,
112108 "./gpu/speculate_decoding_kernels/ngram_match.cc" ,
113109]
114- sources += find_end_files ("./gpu/append_attn/template_instantiation" , ".cu" )
115110sources += find_end_files ("./gpu/speculate_decoding_kernels" , ".cu" )
116111
117112nvcc_compile_args = gencode_flags
@@ -138,6 +133,14 @@ def get_gencode_flags():
138133if cc >= 80 :
139134 sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu" ]
140135
136+ sources += [
137+ "./gpu/append_attention.cu" ,
138+ "./gpu/append_attn/get_block_shape_and_split_kv_block.cu" ,
139+ "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu" ,
140+ "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu" ,
141+ ]
142+ sources += find_end_files ("./gpu/append_attn/template_instantiation" , ".cu" )
143+
141144if cc >= 89 and cuda_version >= 12.4 :
142145 os .system ("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py" )
143146 os .system ("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py" )
0 commit comments