Skip to content

Commit 621829b

Browse files
committed
Refactor
1 parent 54c4946 commit 621829b

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

ext/AMDGPUExt/functor.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
1818

1919
adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng()
2020

21-
function ChainRulesCore.rrule(::Type{Array}, x::ROCArray)
22-
Array(x), dx -> (NoTangent(), ROCArray(unthunk(dx)))
23-
end
24-
2521
function ChainRulesCore.rrule(
2622
::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray,
2723
)
@@ -32,9 +28,8 @@ end
3228

3329
function _amd(x)
3430
check_use_amdgpu()
35-
USE_AMDGPU[] ?
36-
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) :
37-
x
31+
USE_AMDGPU[] || return x
32+
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf)
3833
end
3934

4035
# Since MIOpen supports only cross-correlation as convolution,

test/runtests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,18 @@ Random.seed!(0)
6363
end
6464
end
6565

66-
if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true"
67-
using AMDGPU
68-
AMDGPU.versioninfo()
69-
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
70-
@show AMDGPU.MIOpen.version()
71-
@testset "AMDGPU" begin
72-
include("amd/runtests.jl")
73-
end
74-
else
75-
@info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
76-
end
66+
if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true"
67+
using AMDGPU
68+
AMDGPU.versioninfo()
69+
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
70+
@show AMDGPU.MIOpen.version()
71+
@testset "AMDGPU" begin
72+
include("amd/runtests.jl")
73+
end
7774
else
78-
@info "Skipping AMDGPU tests, set FLUX_TEST_CUDA=true to run them."
75+
@info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
7976
end
77+
else
78+
@info "Skipping AMDGPU tests, set FLUX_TEST_AMDGPU=true to run them."
79+
end
8080
end

0 commit comments

Comments
 (0)