Skip to content

Commit 3cefe05

Browse files
cthifacebook-github-bot
authored andcommitted
Fix ROCm build I broke (#4902)
Summary: X-link: facebookresearch/FBGEMM#1929 Pull Request resolved: #4902 Broke the ROCm build by accident in D82855103 :(. Since the function is cuda specific, we can just only include it when being used for cuda for now. For some reason it cannot build properly standalone still in fbcode, which Is why I originally tried to move it to inline header, but that broke OSS since hipify would not get run. Reviewed By: q10 Differential Revision: D82895451 fbshipit-source-id: 41553cf51c8a93f72f891e206de97ff00fd108dc
1 parent 8784fab commit 3cefe05

File tree

6 files changed

+34
-17
lines changed

6 files changed

+34
-17
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize/utils.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#include <climits>
1212
#include <cstdint>
1313

14-
#include <ATen/cuda/CUDAContext.h>
15-
1614
namespace fbgemm_gpu {
1715

1816
constexpr int64_t nextPowerOf2(int64_t num) {
@@ -21,19 +19,4 @@ constexpr int64_t nextPowerOf2(int64_t num) {
2119
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
2220
}
2321

24-
inline int getDeviceArch() {
25-
static int arch = []() {
26-
const int majorVersion =
27-
at::cuda::getDeviceProperties(at::cuda::current_device())->major;
28-
if (majorVersion >= 10) {
29-
int runtimeVersion = 0;
30-
C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
31-
TORCH_CHECK(
32-
runtimeVersion >= 12080, "SM100a+ kernels require cuda >= 12.8");
33-
}
34-
return majorVersion;
35-
}();
36-
return arch;
37-
}
38-
3922
} // namespace fbgemm_gpu
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ATen/cuda/CUDAContext.h>
12+
13+
namespace fbgemm_gpu {
14+
15+
inline int getDeviceArch() {
16+
static int arch = []() {
17+
const int majorVersion =
18+
at::cuda::getDeviceProperties(at::cuda::current_device())->major;
19+
if (majorVersion >= 10) {
20+
int runtimeVersion = 0;
21+
C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
22+
TORCH_CHECK(
23+
runtimeVersion >= 12080, "SM100a+ kernels require cuda >= 12.8");
24+
}
25+
return majorVersion;
26+
}();
27+
return arch;
28+
}
29+
30+
} // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh"
1313
#include "fbgemm_gpu/quantize/tuning_cache.hpp"
1414
#include "fbgemm_gpu/quantize/utils.h"
15+
#include "fbgemm_gpu/quantize/utils_gpu.h"
1516

1617
namespace fbgemm_gpu {
1718

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_groupwise.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "f8f8bf16_groupwise/f8f8bf16_groupwise_manifest.cuh"
1414
#include "fbgemm_gpu/quantize/tuning_cache.hpp"
1515
#include "fbgemm_gpu/quantize/utils.h"
16+
#include "fbgemm_gpu/quantize/utils_gpu.h"
1617

1718
namespace fbgemm_gpu {
1819

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_manifest.cuh"
1212

1313
#include "fbgemm_gpu/quantize/utils.h"
14+
#include "fbgemm_gpu/quantize/utils_gpu.h"
1415

1516
namespace fbgemm_gpu {
1617

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "f8f8bf16_rowwise_grouped_sm100/f8f8bf16_rowwise_grouped_manifest.cuh"
1515
#include "fbgemm_gpu/quantize/tuning_cache.hpp"
1616
#include "fbgemm_gpu/quantize/utils.h"
17+
#include "fbgemm_gpu/quantize/utils_gpu.h"
1718

1819
namespace fbgemm_gpu {
1920

0 commit comments

Comments
 (0)