Commit 944be5a

[rocm7.0_internal_testing] Prevent static initialization of at::cuda::warp_size() (#2293)
Fixes SWDEV-540240, SWDEV-540309, SWDEV-539989

### Error

```
#24 437.7 what(): HIP error: no ROCm-capable device is detected
#24 437.7 HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
#24 437.7 For debugging consider passing AMD_SERIALIZE_KERNEL=3
#24 437.7 Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.
#24 437.7 Exception raised from c10_hip_check_implementation at /pytorch/c10/hip/HIPException.cpp:44 (most recent call first):
#24 437.7 frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x88 (0x7f272de18738 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
#24 437.7 frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x55 (0x7f272ddb42ed in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
...
#24 437.7 frame #7: at::cuda::getCurrentDeviceProperties() + 0x9 (0x7f270b5874e9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_hip.so)
#24 437.7 frame #8: at::cuda::warp_size() + 0x9 (0x7f270b587509 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_hip.so)
#24 437.7 frame #9: <unknown function> + 0x81ac8b (0x7f2709c27c8b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_hip.so)
```

### Explanation

80cca70 introduced a static global variable whose initializer calls `at::cuda::warp_size()`, which requires a visible GPU to query device properties. CPU-only build systems have no GPU, so the static initializer aborts while the library is being loaded, before any user code runs.

### Solution

Convert the static variable into a static function, deferring the device query from static-initialization time to the first call.

### Validation

http://rocm-ci.amd.com/job/pyt_whl_docker_mainline/1461/artifact/build_artifacts.txt/*view*/

Ran a microbenchmark to confirm basic functionality:

```
root@ubb4-rack-22:/var/lib/jenkins/pytorch-micro-benchmarking# python3 micro_benchmarking_pytorch.py --network resnet50
INFO: running forward and backward for warmup.
INFO: running the benchmark..
OK: finished running benchmark..
--------------------SUMMARY--------------------------
Microbenchmark for network : resnet50
Num devices: 1
Dtype: FP32
Mini batch size [img] : 64
Time per mini-batch : 0.10158218145370483
Throughput [img/sec] : 630.0317544289736
```
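The failure mode generalizes beyond this commit: any namespace-scope static whose initializer touches the GPU runs during static initialization, before `main()` and before anything can check device availability. Below is a minimal standalone sketch of the broken pattern and the fix; `query_warp_size()` is a hypothetical stand-in for `at::cuda::warp_size()`, not the real API:

```cpp
#include <iostream>
#include <stdexcept>

// Hypothetical stand-in for at::cuda::warp_size(): throws when no GPU is
// visible, roughly as the real call fails on a CPU-only machine.
int query_warp_size() {
  throw std::runtime_error("no ROCm-capable device is detected");
}

// BROKEN pattern: this initializer would run at load time, before main(),
// so a CPU-only host aborts before any user code can react.
// static const int kMaxThreads = query_warp_size() * query_warp_size();

// FIXED pattern: a function body only runs when called, so CPU-only hosts
// that never reach GPU code paths never trigger the device query.
static int kMaxThreads() {
  return query_warp_size() * query_warp_size();
}

int main() {
  std::cout << "library loaded fine without a GPU\n";
  try {
    std::cout << kMaxThreads() << "\n";  // query happens here, catchable
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
}
```

The trade-off is that every call site grows a pair of parentheses, which is exactly the mechanical change the three kernel diffs below make.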
Parent: 347efdf

File tree

4 files changed: 9 additions, 5 deletions
aten/src/ATen/native/cuda/Embedding.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
 
   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
-                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
+                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(),
                         "BlockReduceSum requires all warps be active");
   const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
   dim3 grid = unique_indices.numel();
```

aten/src/ATen/native/cuda/MultinomialKernel.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ void renormRows(Tensor& t) {
   TORCH_CHECK(props != nullptr);
   int numSM = props->multiProcessorCount;
   const int64_t maxThreads = std::min(
-      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
+      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads());
 
   int warp_size = at::cuda::warp_size();
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
```

aten/src/ATen/native/cuda/TensorModeKernel.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -207,7 +207,7 @@ void handle_fused_mode(
   constexpr int num_threads = size / 2;
   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 &&
-      num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
+      num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), "");
   const auto memsize =
       (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
   compute_mode<scalar_t, size>
```

aten/src/ATen/native/cuda/block_reduce.cuh

Lines changed: 6 additions & 2 deletions
```diff
@@ -15,9 +15,13 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
 // and kCUDABlockReduceMaxThreads is a host-side variable.
 #ifdef USE_ROCM
-static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size();
+static int kCUDABlockReduceMaxThreads() {
+  return at::cuda::warp_size() * at::cuda::warp_size();
+}
 #else
-constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+constexpr int kCUDABlockReduceMaxThreads() {
+  return C10_WARP_SIZE * C10_WARP_SIZE;
+}
 #endif
 
 // Sums `val` across all threads in a warp.
```
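One design note on the new ROCm branch: each call re-enters `at::cuda::warp_size()` and therefore re-queries device properties. If that ever mattered on a hot path, a function-local static could cache the product after the first successful query while keeping initialization lazy. This is a hypothetical refinement, not part of this commit:

```cpp
#include <ATen/cuda/CUDAContext.h>  // declares at::cuda::warp_size()

// Hypothetical memoized variant (not in this commit). The function-local
// static is initialized on first call (thread-safe since C++11), so a
// CPU-only host still loads the library fine and later calls skip the
// device-property query. Caveat: the value is pinned to whichever device
// is current at the first call.
static int kCUDABlockReduceMaxThreads() {
  static const int cached = at::cuda::warp_size() * at::cuda::warp_size();
  return cached;
}
```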
