-
-
Notifications
You must be signed in to change notification settings - Fork 127
Add AMDGPU extension #470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add AMDGPU extension #470
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
module AMDGPUExt | ||
|
||
using Adapt | ||
using AMDGPU | ||
using AMDGPU.MIOpen | ||
using ChainRulesCore | ||
using NNlib | ||
using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans | ||
using NNlib: DenseConvDims, PoolDims | ||
|
||
const MIOPENFloat = Union{Float16, Float32} | ||
|
||
const ROCBatchedAdjoint{T} = BatchedAdjoint{T, <: ROCArray{T}} | ||
const ROCBatchedTranspose{T} = BatchedTranspose{T, <: ROCArray{T}} | ||
const ROCBatchedAdjOrTrans{T} = Union{ROCBatchedAdjoint{T}, ROCBatchedTranspose{T}} | ||
const WrappedROCBatchedAdjOrTrans{T, N} = Adapt.WrappedArray{T, N, ROCBatchedAdjOrTrans{T}, ROCBatchedAdjOrTrans{T}} | ||
const AnyROCBatchedAdjOrTrans = Union{ROCBatchedAdjOrTrans, WrappedROCBatchedAdjOrTrans} | ||
|
||
function Base.convert(::Type{T}, b::AnyROCBatchedAdjOrTrans) where {T <: Array} | ||
Base.convert(T, adapt(Array, b)) | ||
end | ||
|
||
function Base.Array{T, N}(b::AnyROCBatchedAdjOrTrans) where {T, N} | ||
Array{T, N}(adapt(Array, b)) | ||
end | ||
|
||
Base.collect(b::AnyROCBatchedAdjOrTrans) = collect(adapt(Array, b)) | ||
|
||
function Base.show( | ||
io::IO, mime::MIME{Symbol("text/plain")}, x::AnyROCBatchedAdjOrTrans, | ||
) | ||
show(io, mime, adapt(Array, x)) | ||
end | ||
|
||
Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x)) | ||
|
||
Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x)) | ||
|
||
function NNlib._batched_gemm!( | ||
::Type{<: ROCArray}, transA::Char, transB::Char, α, A, B, β, C, | ||
) | ||
AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C) | ||
end | ||
|
||
function nnlib_padding(dims) | ||
pd = NNlib.padding(dims) | ||
if !all(pd[1:2:end] .== pd[2:2:end]) | ||
@warn """ | ||
MIOpen does not support asymmetric padding, defaulting to symmetric choice: | ||
$pd -> $(pd[1:2:end]). | ||
""" maxlog=1 | ||
end | ||
pd[1:2:end] | ||
end | ||
|
||
include("conv.jl") | ||
include("pool.jl") | ||
include("softmax.jl") | ||
include("activations.jl") | ||
|
||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
for (f, op) in [ | ||
NNlib.relu => MIOpen.relu, | ||
NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6), | ||
NNlib.softplus => MIOpen.softrelu, | ||
NNlib.σ => MIOpen.sigmoid, | ||
Base.tanh => MIOpen.tanh, | ||
# TODO define for leakyrelu, elu, etc.? | ||
] | ||
@eval function Base.materialize( | ||
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat}}} | ||
) | ||
return $op(bc.args[1]) | ||
end | ||
end | ||
|
||
Base.broadcasted(::typeof(identity), x::ROCArray{T}) where {T<:MIOPENFloat} = x |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
function NNlib.conv!( | ||
y::ROCArray{T, N}, x::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims, | ||
) where {T <: MIOPENFloat, N} | ||
NNlib.flipkernel(cdims) || throw(ArgumentError( | ||
"MIOpen supports only cross-correlation as its convolution implementation.")) | ||
|
||
nd = max(0, 4 - N) | ||
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) | ||
MIOpen.convolution!( | ||
NNlib.insert_singleton_spatial_dimension(y, nd), | ||
NNlib.insert_singleton_spatial_dimension(x, nd), | ||
NNlib.insert_singleton_spatial_dimension(w, nd); | ||
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), | ||
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) | ||
return y | ||
end | ||
|
||
function NNlib.∇conv_data!( | ||
dx::ROCArray{T, N}, dy::ROCArray{T, N}, w::ROCArray{T, N}, cdims::DenseConvDims, | ||
) where {T <: MIOPENFloat, N} | ||
NNlib.flipkernel(cdims) || throw(ArgumentError( | ||
"MIOpen supports only cross-correlation as its convolution implementation.")) | ||
|
||
nd = max(0, 4 - N) | ||
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) | ||
MIOpen.∇convolution_data!( | ||
NNlib.insert_singleton_spatial_dimension(dx, nd), | ||
NNlib.insert_singleton_spatial_dimension(dy, nd), | ||
NNlib.insert_singleton_spatial_dimension(w, nd); | ||
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), | ||
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) | ||
return dx | ||
end | ||
|
||
function NNlib.∇conv_filter!( | ||
dw::ROCArray{T, N}, x::ROCArray{T, N}, dy::ROCArray{T, N}, cdims::DenseConvDims, | ||
) where {T <: MIOPENFloat, N} | ||
NNlib.flipkernel(cdims) || throw(ArgumentError( | ||
"MIOpen supports only cross-correlation as its convolution implementation.")) | ||
|
||
nd = max(0, 4 - N) | ||
ncdims = NNlib.insert_singleton_spatial_dimension(cdims, nd) | ||
MIOpen.∇convolution_weight!( | ||
NNlib.insert_singleton_spatial_dimension(dw, nd), | ||
NNlib.insert_singleton_spatial_dimension(dy, nd), | ||
NNlib.insert_singleton_spatial_dimension(x, nd); | ||
padding=nnlib_padding(ncdims), stride=NNlib.stride(ncdims), | ||
dilation=NNlib.dilation(ncdims), groups=NNlib.groupcount(ncdims)) | ||
return dw | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
for poolname in (:maxpool, :meanpool) | ||
@eval function NNlib.$(poolname)( | ||
x::ROCArray{T, N}, pdims::PoolDims, | ||
) where {T <: MIOPENFloat, N} | ||
y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) | ||
nd = max(0, 4 - N) | ||
npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) | ||
MIOpen.$(Symbol("$(poolname)!"))( | ||
NNlib.insert_singleton_spatial_dimension(y, nd), | ||
NNlib.insert_singleton_spatial_dimension(x, nd); | ||
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), | ||
stride=NNlib.stride(npdims), do_backward=false) | ||
return y | ||
end | ||
|
||
@eval function ChainRulesCore.rrule( | ||
::typeof(NNlib.$(poolname)), x::ROCArray{T, N}, pdims::PoolDims, | ||
) where {T <: MIOPENFloat, N} | ||
y = similar(x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) | ||
nd = max(0, 4 - N) | ||
npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) | ||
|
||
# `workspace` is used in the pullback. | ||
_, workspace = MIOpen.$(Symbol("$(poolname)!"))( | ||
NNlib.insert_singleton_spatial_dimension(y, nd), | ||
NNlib.insert_singleton_spatial_dimension(x, nd); | ||
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), | ||
stride=NNlib.stride(npdims)) | ||
|
||
function _pooling_pullback(Δ) | ||
dx = similar(x) | ||
MIOpen.$(Symbol("∇$(poolname)!"))( | ||
NNlib.insert_singleton_spatial_dimension(dx, nd), | ||
NNlib.insert_singleton_spatial_dimension(unthunk(Δ), nd), | ||
NNlib.insert_singleton_spatial_dimension(y, nd), | ||
NNlib.insert_singleton_spatial_dimension(x, nd); | ||
dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), | ||
stride=NNlib.stride(npdims), workspace) | ||
return NoTangent(), dx, NoTangent() | ||
end | ||
y, _pooling_pullback | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
for fname in (:softmax, :logsoftmax) | ||
@eval function NNlib.$(fname)(x::ROCArray{T}; dims = 1) where T <: MIOPENFloat | ||
MIOpen.$(fname)(x; dims) | ||
end | ||
|
||
@eval function NNlib.$(Symbol("∇$(fname)"))( | ||
dy::ROCArray{T, N}, x::ROCArray{T, N}, y::ROCArray{T, N}; dims = 1, | ||
) where {T <: MIOPENFloat, N} | ||
MIOpen.$(Symbol("∇$(fname)!"))(dy, y; dims) | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
@testset "Compare CPU & GPU" begin | ||
for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5)) | ||
x = randn(T, 16) | ||
gputest(x -> NNlib.relu.(x), x; atol) | ||
gputest(x -> NNlib.relu6.(x), x; atol) | ||
gputest(x -> NNlib.softplus.(x), x; atol) | ||
gputest(x -> tanh.(x), x; atol) | ||
gputest(x -> identity.(x), x; atol) | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
@testset "batched_mul" begin | ||
A = rand(Float32, 3, 3, 2) | ||
B = rand(Float32, 3, 3, 2) | ||
dA, dB = ROCArray.((A, B)) | ||
|
||
C = batched_mul(A, B) | ||
@test ROCArray(C) ≈ batched_mul(dA, dB) | ||
|
||
Ct = batched_mul(batched_transpose(A), B) | ||
@test ROCArray(Ct) ≈ batched_mul(batched_transpose(dA), dB) | ||
|
||
Ca = batched_mul(A, batched_adjoint(B)) | ||
@test ROCArray(Ca) ≈ batched_mul(dA, batched_adjoint(dB)) | ||
|
||
# 5-arg batched_mul! | ||
C .= pi | ||
batched_mul!(C, A, B, 2f0, 3f0) | ||
Cpi = ROCArray(similar(C)) .= pi | ||
@test ROCArray(C) ≈ batched_mul!(Cpi, dA, dB, 2f0, 3f0) | ||
|
||
# PermutedDimsArray | ||
@test ROCArray(Ct) ≈ batched_mul(PermutedDimsArray(dA, (2, 1, 3)), dB) | ||
|
||
# FIXME same but with (1, 3, 2) errors | ||
D = permutedims(B, (2, 1, 3)) | ||
Cp = batched_mul(batched_adjoint(A), B) | ||
@test ROCArray(Cp) ≈ batched_mul( | ||
batched_adjoint(dA), PermutedDimsArray(ROCArray(D), (2, 1, 3))) | ||
|
||
# Methods which reshape | ||
M = randn(Float32, 3, 3) | ||
Cm = batched_mul(A, M) | ||
@test ROCArray(Cm) ≈ batched_mul(dA, ROCArray(M)) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
function print_array_strs(x) | ||
str = sprint((io, x)->show(io, MIME"text/plain"(), x), x) | ||
return @view split(str, '\n')[2:end] | ||
end | ||
|
||
@testset "BatchedAdjOrTrans" begin | ||
x = rand(Float32, 3, 4, 2) | ||
y = ROCArray(x) | ||
|
||
bax = batched_adjoint(x) | ||
btx = batched_transpose(x) | ||
bay = batched_adjoint(y) | ||
bty = batched_transpose(y) | ||
|
||
@test sprint(show, bax) == sprint(show, bay) | ||
@test sprint(show, btx) == sprint(show, bty) | ||
|
||
@test print_array_strs(bax) == print_array_strs(bay) | ||
@test print_array_strs(btx) == print_array_strs(bty) | ||
|
||
@test Array(bax) == Array(bay) | ||
@test collect(bax) == collect(bay) | ||
@test Array(btx) == Array(bty) | ||
@test collect(btx) == collect(bty) | ||
|
||
for shape in (:, (12, 2)) | ||
rbax = reshape(bax, shape) | ||
rbtx = reshape(btx, shape) | ||
rbay = reshape(bay, shape) | ||
rbty = reshape(bty, shape) | ||
|
||
@test sprint(show, rbax) == sprint(show, rbay) | ||
@test sprint(show, rbtx) == sprint(show, rbty) | ||
|
||
@test print_array_strs(rbax) == print_array_strs(rbay) | ||
@test print_array_strs(rbtx) == print_array_strs(rbty) | ||
|
||
@test Array(rbax) == Array(rbay) | ||
@test collect(rbax) == collect(rbay) | ||
@test Array(rbtx) == Array(rbty) | ||
@test collect(rbtx) == collect(rbty) | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
@testset "Compare CPU & GPU" begin | ||
channels, batch = 3, 2 | ||
for T in (Float16, Float32), nd in (1, 2, 3) | ||
x = rand(Float32, fill(4, nd)..., 3, 1) | ||
w = rand(Float32, fill(2, nd)..., channels, 4) | ||
cdims = DenseConvDims(x, w, flipkernel=true) | ||
gputest((x, w) -> NNlib.conv(x, w, cdims), x, w; atol=1e-4) | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
@testset "Compare CPU & GPU" begin | ||
channels, batch = 3, 2 | ||
for T in (Float16, Float32), nd in (1, 2, 3) | ||
x = rand(T, fill(8, nd)..., channels, batch) | ||
pdims = PoolDims(x, 2) | ||
# NOTE: Disable grad check for maxpool as *sometimes* | ||
# it does not *completely* agree with CPU :/ | ||
gputest(x -> NNlib.maxpool(x, pdims), x; checkgrad=false) | ||
gputest(x -> NNlib.meanpool(x, pdims), x) | ||
end | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.