Skip to content

Commit 5811a8d

Browse files
pytorchbot and eqy authored
[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (pytorch#167327)
[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (pytorch#167111) cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. If we want to allow users to resolve pytorch#166643 on their end by updating their cuDNN version locally, we need to check the runtime version rather than the compile-time version. Pull Request resolved: pytorch#167111 Approved by: https://github.com/Skylion007 (cherry picked from commit e678450) Co-authored-by: Eddie Yan <[email protected]>
1 parent f36c764 commit 5811a8d

File tree

7 files changed

+43
-9
lines changed

7 files changed

+43
-9
lines changed

aten/src/ATen/Context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ class TORCH_API Context {
155155
static long versionCuDNN() {
156156
return detail::getCUDAHooks().versionCuDNN();
157157
}
158+
static long versionRuntimeCuDNN() {
159+
return detail::getCUDAHooks().versionRuntimeCuDNN();
160+
}
161+
static long versionCuDNNFrontend() {
162+
return detail::getCUDAHooks().versionCuDNNFrontend();
163+
}
158164
static bool hasCuSOLVER() {
159165
return detail::getCUDAHooks().hasCuSOLVER();
160166
}

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#if AT_CUDNN_ENABLED()
2323
#include <ATen/cudnn/cudnn-wrapper.h>
24+
#include <cudnn_frontend.h>
2425
#endif
2526

2627
#if AT_MAGMA_ENABLED()
@@ -325,6 +326,26 @@ long CUDAHooks::versionCuDNN() const {
325326
#endif
326327
}
327328

329+
long CUDAHooks::versionRuntimeCuDNN() const {
330+
#if AT_CUDNN_ENABLED()
331+
#ifndef USE_STATIC_CUDNN
332+
return cudnnGetVersion();
333+
#else
334+
return CUDNN_VERSION;
335+
#endif
336+
#else
337+
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
338+
#endif
339+
}
340+
341+
long CUDAHooks::versionCuDNNFrontend() const {
342+
#if AT_CUDNN_ENABLED()
343+
return CUDNN_FRONTEND_VERSION;
344+
#else
345+
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
346+
#endif
347+
}
348+
328349
long CUDAHooks::versionMIOpen() const {
329350
#if AT_ROCM_ENABLED()
330351
return MIOPEN_VERSION_MAJOR * 10000 +

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
4848
bool hasCUDART() const override;
4949
long versionCUDART() const override;
5050
long versionCuDNN() const override;
51+
long versionRuntimeCuDNN() const override;
52+
long versionCuDNNFrontend() const override;
5153
long versionMIOpen() const override;
5254
std::string showConfig() const override;
5355
double batchnormMinEpsilonCuDNN() const override;

aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
170170
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
171171
}
172172

173+
virtual long versionRuntimeCuDNN() const {
174+
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
175+
}
176+
177+
virtual long versionCuDNNFrontend() const {
178+
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
179+
}
180+
173181
virtual long versionMIOpen() const {
174182
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
175183
}

aten/src/ATen/native/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ struct ConvParams {
413413
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
414414
return false;
415415
}
416-
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
416+
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
417417
// broken on cuDNN 9.8 - 9.14
418418
if (cudnn_version >= 90800 && cudnn_version < 91500) {
419419
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@@ -457,7 +457,7 @@ struct ConvParams {
457457
}
458458
// native kernel doesn't support 64-bit non-splittable case
459459
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
460-
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
460+
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
461461
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
462462
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
463463
if (cudnn_version < 0 || cudnn_version > 91000) {

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
437437
const auto s_k = params.key.sym_size(2);
438438
const auto d_qk = params.query.sym_size(3);
439439
const auto d_v = params.value.sym_size(3);
440-
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
440+
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
441441
if (cudnn_version < 8903) {
442442
if (debug) {
443443
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
@@ -668,7 +668,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
668668
return false;
669669
#endif
670670
#if defined(CUDNN_VERSION)
671-
static auto cudnn_version = cudnnGetVersion();
671+
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
672672
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
673673
if (debug) {
674674
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13).");

torch/csrc/cuda/shared/cudnn.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// This file should only be compiled if this condition holds, so it should be
33
// safe.
44
#if defined(USE_CUDNN) || defined(USE_ROCM)
5+
#include <ATen/detail/CUDAHooksInterface.h>
56
#include <torch/csrc/utils/pybind.h>
67

78
#include <tuple>
@@ -32,11 +33,7 @@ version_tuple getRuntimeVersion() {
3233
}
3334

3435
size_t getVersionInt() {
35-
#ifndef USE_STATIC_CUDNN
36-
return cudnnGetVersion();
37-
#else
38-
return CUDNN_VERSION;
39-
#endif
36+
return at::detail::getCUDAHooks().versionRuntimeCuDNN();
4037
}
4138

4239
} // namespace

0 commit comments

Comments (0)