@@ -57,8 +57,7 @@ def strtobool(v):
57
57
58
58
def get_gencode_flags ():
59
59
if not strtobool (os .getenv ("FLAG_LLM_PDC" , "False" )):
60
- prop = paddle .device .cuda .get_device_properties ()
61
- cc = prop .major * 10 + prop .minor
60
+ cc = get_sm_version ()
62
61
return ["-gencode" , "arch=compute_{0},code=sm_{0}" .format (cc )]
63
62
else :
64
63
# support more cuda archs
@@ -75,6 +74,7 @@ def get_gencode_flags():
75
74
gencode_flags = get_gencode_flags ()
76
75
library_path = os .environ .get ("LD_LIBRARY_PATH" , "/usr/local/cuda/lib64" )
77
76
77
+ sm_version = get_sm_version ()
78
78
79
79
sources = [
80
80
"./gpu/save_with_output.cc" ,
@@ -102,16 +102,11 @@ def get_gencode_flags():
102
102
"./gpu/dequant_int8.cu" ,
103
103
"./gpu/flash_attn_bwd.cc" ,
104
104
"./gpu/tune_cublaslt_gemm.cu" ,
105
- "./gpu/append_attention.cu" ,
106
- "./gpu/append_attn/get_block_shape_and_split_kv_block.cu" ,
107
- "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu" ,
108
- "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu" ,
109
105
"./gpu/sample_kernels/top_p_sampling_reject.cu" ,
110
106
"./gpu/update_inputs_v2.cu" ,
111
107
"./gpu/set_preids_token_penalty_multi_scores.cu" ,
112
108
"./gpu/speculate_decoding_kernels/ngram_match.cc" ,
113
109
]
114
- sources += find_end_files ("./gpu/append_attn/template_instantiation" , ".cu" )
115
110
sources += find_end_files ("./gpu/speculate_decoding_kernels" , ".cu" )
116
111
117
112
nvcc_compile_args = gencode_flags
@@ -138,6 +133,14 @@ def get_gencode_flags():
138
133
if cc >= 80 :
139
134
sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu" ]
140
135
136
+ sources += [
137
+ "./gpu/append_attention.cu" ,
138
+ "./gpu/append_attn/get_block_shape_and_split_kv_block.cu" ,
139
+ "./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu" ,
140
+ "./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu" ,
141
+ ]
142
+ sources += find_end_files ("./gpu/append_attn/template_instantiation" , ".cu" )
143
+
141
144
if cc >= 89 and cuda_version >= 12.4 :
142
145
os .system ("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py" )
143
146
os .system ("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py" )
0 commit comments