Skip to content

Commit e3df3f6

Browse files
reintegrate has_cudnn check (#2200)
* reintegrate has_cudnn check
* typo
1 parent 484796c commit e3df3f6

File tree

3 files changed

+7
-0
lines changed

3 files changed

+7
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2323
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2424
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2627

2728
[compat]
2829
AMDGPU = "0.4.8"
@@ -42,6 +43,7 @@ Reexport = "0.2, 1.0"
4243
SpecialFunctions = "1.8.2, 2.1.2"
4344
StatsBase = "0.33"
4445
Zygote = "0.6.49"
46+
cuDNN = "1"
4547
julia = "1.6"
4648

4749
[extensions]

src/Flux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ using .Train
4444
using .Train: setup
4545

4646
using CUDA
47+
import cuDNN
4748
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
4849

4950
using Adapt, Functors, OneHotArrays

src/functor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,16 @@ end
253253
function check_use_cuda()
254254
if use_cuda[] === nothing
255255
use_cuda[] = CUDA.functional()
256+
if use_cuda[] && !cuDNN.has_cudnn()
257+
@warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." maxlog=1
258+
end
256259
if !(use_cuda[])
257260
@info """The GPU function is being called but the GPU is not accessible.
258261
Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1
259262
end
260263
end
261264
end
265+
262266
ChainRulesCore.@non_differentiable check_use_cuda()
263267

264268
# Precision

0 commit comments

Comments (0)