Skip to content

Commit 3f99f18

Browse files
authored
perfix: use lightweight API to query device property (#1298)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 6bb969b commit 3f99f18

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,17 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
5656
bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
5757
int64_t const routing_method_type) {
5858
auto device = hidden_states.device();
59-
cudaDeviceProp prop;
60-
cudaGetDeviceProperties(&prop, device.index());
61-
TORCH_CHECK(prop.major == 10 && prop.minor == 0,
62-
"This kernel requires SM 100 architecture. Current device has SM ", prop.major,
63-
prop.minor, " (", prop.name, ")");
59+
60+
static const std::tuple<int, int> device_props = [&device] {
61+
int major, minor;
62+
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
63+
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
64+
return std::make_tuple(major, minor);
65+
}();
66+
67+
TORCH_CHECK(std::get<0>(device_props) == 10 && std::get<1>(device_props) == 0,
68+
"This kernel requires SM 100 architecture. Current device has SM ",
69+
std::get<0>(device_props), std::get<1>(device_props));
6470

6571
if (use_routing_scales_on_input) {
6672
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16,
@@ -313,11 +319,17 @@ at::Tensor trtllm_fp8_block_scale_moe_launcher(
313319
tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner& moe_runner,
314320
int64_t moeConfigIndex) {
315321
auto device = hidden_states.device();
316-
cudaDeviceProp prop;
317-
cudaGetDeviceProperties(&prop, device.index());
318-
TORCH_CHECK(prop.major == 10 && prop.minor == 0,
319-
"This kernel requires SM 100 architecture. Current device has SM ", prop.major,
320-
prop.minor, " (", prop.name, ")");
322+
323+
static const std::tuple<int, int> device_props = [&device] {
324+
int major, minor;
325+
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
326+
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
327+
return std::make_tuple(major, minor);
328+
}();
329+
330+
TORCH_CHECK(std::get<0>(device_props) == 10 && std::get<1>(device_props) == 0,
331+
"This kernel requires SM 100 architecture. Current device has SM ",
332+
std::get<0>(device_props), std::get<1>(device_props));
321333

322334
TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float,
323335
"routing_logits must be float.");
@@ -593,11 +605,17 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
593605
tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner& moe_runner,
594606
int64_t const moeConfigIndex) {
595607
auto device = hidden_states.device();
596-
cudaDeviceProp prop;
597-
cudaGetDeviceProperties(&prop, device.index());
598-
TORCH_CHECK(prop.major == 10 && prop.minor == 0,
599-
"This kernel requires SM 100 architecture. Current device has SM ", prop.major,
600-
prop.minor, " (", prop.name, ")");
608+
609+
static const std::tuple<int, int> device_props = [&device] {
610+
int major, minor;
611+
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
612+
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
613+
return std::make_tuple(major, minor);
614+
}();
615+
616+
TORCH_CHECK(std::get<0>(device_props) == 10 && std::get<1>(device_props) == 0,
617+
"This kernel requires SM 100 architecture. Current device has SM ",
618+
std::get<0>(device_props), std::get<1>(device_props));
601619

602620
TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 ||
603621
tile_tokens_dim == 64,

0 commit comments

Comments
 (0)