Will Bijectors support transformations on CuArrays, or perhaps a subset of transformations?
Currently, applying a Stacked transformation to a CuArray works in the MWE below, but differentiating through the log-abs-det-Jacobian computation (Bijectors.with_logabsdet_jacobian) with Zygote fails:
# ]activate --temp
# ]add Bijectors, Zygote, CUDA
using Bijectors
using Zygote
using CUDA

x = [-2, 3.0, 2.1, 2.2]  # promoted to Vector{Float64}
ranges = [1:1, 2:2, 3:4]
# identity on x[1], exp on x[2], exp on x[3:4]
tr = Stacked((elementwise(identity), elementwise(exp), elementwise(exp)), ranges)

xg = CuArray(x)
y = tr(xg)                    # works
gr = Zygote.jacobian(tr, xg)  # works
y, logjac = Bijectors.with_logabsdet_jacobian(tr, xg)
gr = Zygote.jacobian(x -> Bijectors.with_logabsdet_jacobian(tr, x)[1], xg)  # fails with the error below
This fails with the following error:
julia> gr = Zygote.jacobian(x -> Bijectors.with_logabsdet_jacobian(tr, xg)[1], x)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}}, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not a bitstype:
.f is of type Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)} which is not isbits.
.cx is of type Zygote.Context{false} which is not isbits.
.cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.
Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/scratch/twutz/julia_cluster_depots/packages/GPUCompiler/Nxf8r/src/validation.jl:108
...
[63] jacobian(f::Function, args::Vector{Float64})
@ Zygote ~/scratch/twutz/julia_cluster_depots/packages/Zygote/59YyM/src/lib/grad.jl:168
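For reference, here is a short derivation of what the failing call should produce for the MWE input. It assumes that elementwise(f) applies f componentwise to its block of the Stacked transformation. Writing f for tr, the first output of with_logabsdet_jacobian is y = f(x). Its Jacobian should therefore equal the one that Zygote.jacobian(tr, xg) already returns on the GPU, and the log-abs-det term depends only on the exp-transformed entries:

```math
y = f(x) = \left(x_1,\; e^{x_2},\; e^{x_3},\; e^{x_4}\right)

J = \frac{\partial y}{\partial x} = \operatorname{diag}\left(1,\; e^{x_2},\; e^{x_3},\; e^{x_4}\right)

\log\left|\det J\right| = \log 1 + x_2 + x_3 + x_4 = 3.0 + 2.1 + 2.2 = 7.3
```

Because J is diagonal, the identity block contributes \log 1 = 0. Each exp entry contributes \log e^{x_i} = x_i.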
Environment:
[76274a88] Bijectors v0.15.4
[052768ef] CUDA v5.6.1
[e88e6eb3] Zygote v0.7.2