Commit 210347c

cthi authored and facebook-github-bot committed
Move cudaGetDeviceProperties code to util (#4838)
Summary:
Pull Request resolved: #4838
X-link: facebookresearch/FBGEMM#1864

We should move this into its own function and reuse it instead of copy-pasting it. Also return the actual arch from `prop.major` directly.

Reviewed By: q10

Differential Revision: D81963785

fbshipit-source-id: d65cadeb65300c2ddfcc45f40afa9053f6763308
Parent: 2033a0a

File tree

6 files changed: +42, -66 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h

Lines changed: 2 additions & 0 deletions
@@ -19,4 +19,6 @@ constexpr int64_t nextPowerOf2(int64_t num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }

+int getDeviceArch();
+
 } // namespace fbgemm_gpu
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/quantize/utils.h" // @manual
+
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDAException.h>
+#include <cuda_runtime.h>
+
+namespace fbgemm_gpu {
+
+int getDeviceArch() {
+  static int arch = []() {
+    // Avoid expensive cudaGetDeviceProperties call.
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+
+    if (prop.major >= 10) {
+      int runtimeVersion = 0;
+      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
+      TORCH_CHECK(
+          runtimeVersion >= 12080, "SM100a+ kernels require cuda >= 12.8");
+    }
+
+    return prop.major;
+  }();
+  return arch;
+}
+} // namespace fbgemm_gpu
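For readers outside the FBGEMM tree, here is a minimal standalone sketch of the same once-per-process caching pattern, assuming only the CUDA runtime. `getDeviceArchSketch` and the plain-stderr version warning are illustrative stand-ins for the `C10_CUDA_CHECK`/`TORCH_CHECK` guards above, not FBGEMM APIs; it compiles with `nvcc` and returns `prop.major` directly (e.g. 9 on SM90/Hopper, 10 on SM100/Blackwell).

#include <cstdio>
#include <cuda_runtime.h>

// Query the device's major compute capability once and cache it for the
// lifetime of the process; the immediately-invoked lambda runs a single time,
// so the relatively expensive cudaGetDeviceProperties call is paid only once.
int getDeviceArchSketch() {
  static const int arch = []() {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    if (prop.major >= 10) {
      // Stand-in for the C10_CUDA_CHECK/TORCH_CHECK guard in the real helper.
      int runtimeVersion = 0;
      cudaRuntimeGetVersion(&runtimeVersion);
      if (runtimeVersion < 12080) {
        std::fprintf(stderr, "SM100a+ kernels require cuda >= 12.8\n");
      }
    }
    return prop.major; // e.g. 9 on Hopper (SM90), 10 on Blackwell (SM100)
  }();
  return arch;
}

int main() {
  std::printf("device arch (major compute capability): %d\n",
              getDeviceArchSketch());
  return 0;
}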

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 1 addition & 16 deletions
@@ -207,22 +207,7 @@ at::Tensor dispatch_bf16_grouped_kernel(
     at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 grouped GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
+  const int arch = getDeviceArch();

   // Select kernel to run via heuristics or tuning.
   auto kernel = [&]() {

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_groupwise.cu

Lines changed: 1 addition & 17 deletions
@@ -9,7 +9,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-// clang-format on

 #include "f8f8bf16_groupwise/f8f8bf16_groupwise_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
@@ -64,22 +63,7 @@ at::Tensor dispatch_fp8_groupwise_kernel(
   int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   int K = XQ.size(-1);

-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
+  const int arch = getDeviceArch();

   // Select kernel to run via heuristics or tuning.
   auto kernel = [&]() {

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 3 additions & 16 deletions
@@ -10,6 +10,8 @@
 #include <cute/tensor.hpp>
 #include "f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_manifest.cuh"

+#include "fbgemm_gpu/quantize/utils.h"
+
 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000
@@ -30,22 +32,7 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 batched GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
+  const int arch = getDeviceArch();

   TORCH_CHECK(
       (XQ.dim() == 3 && WQ.dim() == 3),

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 1 addition & 17 deletions
@@ -9,7 +9,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-// clang-format on

 #include "f8f8bf16_rowwise_grouped/f8f8bf16_rowwise_grouped_manifest.cuh"
 #include "f8f8bf16_rowwise_grouped_sm100/f8f8bf16_rowwise_grouped_manifest.cuh"
@@ -32,22 +31,7 @@ TuningCache& getTuningCache() {
 template <typename InputType>
 Kernel_f8f8bf16_rowwise_grouped<InputType>
 get_kernel_via_heuristics(int total_M, int max_N, int max_K, int G) {
-  static int arch = -1;
-  // Avoid expensive cudaGetDeviceProperties call.
-  if (arch < 0) {
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, 0);
-    if (prop.major >= 10) {
-      arch = 10;
-      int runtimeVersion;
-      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
-      TORCH_CHECK(
-          runtimeVersion >= 12080,
-          "FP8 grouped GEMM on sm100a or above requires cuda >= 12.8");
-    } else {
-      arch = 9;
-    }
-  }
+  const int arch = getDeviceArch();

   // Use heuristics to pick the best kernel implementation.
   if (arch == 10) {
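All of the call sites above now reduce to the same dispatch shape. A hedged sketch of that shape as plain, GPU-free C++: `run_sm90_kernel` and `run_sm100_kernel` are hypothetical placeholder names (not FBGEMM symbols), and the stubbed `getDeviceArch` stands in for the real helper so the example runs anywhere.

#include <cstdio>

// Stub standing in for fbgemm_gpu::getDeviceArch(); hardcoded so the
// example runs without a GPU.
int getDeviceArch() { return 9; }

void run_sm90_kernel() { std::puts("Hopper (SM90) kernel path"); }      // hypothetical
void run_sm100_kernel() { std::puts("Blackwell (SM100) kernel path"); } // hypothetical

int main() {
  const int arch = getDeviceArch();
  // Mirrors the heuristic branch in the dispatchers above.
  if (arch == 10) {
    run_sm100_kernel();
  } else {
    run_sm90_kernel();
  }
  return 0;
}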
