Skip to content

Commit be130a7

Browse files
authored
jit: add -lcuda to default ldflags (#1825)
<!-- .github/pull_request_template.md --> ## 📌 Description This PR adds `-lcuda` to the default ldflags instead of making it optional. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 2b753a5 commit be130a7

File tree

4 files changed

+1
-8
lines changed

4 files changed

+1
-8
lines changed

flashinfer/fused_moe/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def gen_cutlass_fused_moe_module(
387387
],
388388
extra_cuda_cflags=nvcc_flags,
389389
extra_cflags=["-DFAST_BUILD"] if use_fast_build else [],
390-
extra_ldflags=["-lcuda"],
391390
extra_include_paths=[
392391
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
393392
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
@@ -1010,7 +1009,6 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
10101009
f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
10111010
]
10121011
+ nvcc_flags,
1013-
extra_ldflags=["-lcuda"],
10141012
extra_include_paths=[
10151013
# link "include" sub-directory in cache
10161014
jit_env.FLASHINFER_CUBIN_DIR / include_path,

flashinfer/gemm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec:
232232
extra_cflags=[
233233
"-DFAST_BUILD",
234234
],
235-
extra_ldflags=["-lcuda"],
236235
)
237236

238237

@@ -279,7 +278,6 @@ def gen_gemm_sm120_module_cutlass_fp4() -> JitSpec:
279278
extra_cflags=[
280279
"-DFAST_BUILD",
281280
],
282-
extra_ldflags=["-lcuda"],
283281
)
284282

285283

@@ -330,7 +328,6 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec:
330328
extra_cflags=[
331329
"-DFAST_BUILD",
332330
],
333-
extra_ldflags=["-lcuda"],
334331
)
335332

336333

@@ -649,7 +646,6 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
649646
+ sm100a_nvcc_flags,
650647
# link "include" sub-directory in cache
651648
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
652-
extra_ldflags=["-lcuda"],
653649
)
654650

655651

flashinfer/jit/attention/modules.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,6 @@ def gen_trtllm_gen_fmha_module():
15861586
jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu",
15871587
jit_env.FLASHINFER_CSRC_DIR / "fmhaReduction.cu",
15881588
],
1589-
extra_ldflags=["-lcuda"],
15901589
# link "include" sub-directory in cache
15911590
extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path],
15921591
extra_cuda_cflags=[
@@ -1691,7 +1690,6 @@ def gen_cudnn_fmha_module():
16911690
return gen_jit_spec(
16921691
"fmha_cudnn_gen",
16931692
[jit_env.FLASHINFER_CSRC_DIR / "cudnn_sdpa_kernel_launcher.cu"],
1694-
extra_ldflags=["-lcuda"],
16951693
extra_cuda_cflags=[
16961694
f'-DCUDNN_SDPA_CUBIN_PATH=\\"{ArtifactPath.CUDNN_SDPA}\\"',
16971695
],

flashinfer/jit/cpp_ext.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def generate_ninja_build_for_op(
171171
"-L$cuda_home/lib64",
172172
"-L$cuda_home/lib64/stubs",
173173
"-lcudart",
174+
"-lcuda",
174175
]
175176

176177
env_extra_ldflags = parse_env_flags("FLASHINFER_EXTRA_LDFLAGS")

0 commit comments

Comments
Β (0)