
Support transformations on CUDA GPUs? #363

@bgctw


Will Bijectors support transformations on CuArrays, or at least a subset of the transformations?

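In the MWE below, applying a `Stacked` transformation to a `CuArray` works, and so does taking its Jacobian with Zygote. However, computing the log absolute determinant of the Jacobian with `with_logabsdet_jacobian`, and differentiating through it, fails: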

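```julia
# ]activate --temp
# ]add Bijectors, Zygote, CUDA
using Bijectors
using Zygote
using CUDA

x = [-2, 3.0, 2.1, 2.2]
ranges = [1:1, 2:2, 3:4]
# identity on x[1], exp on x[2] and exp on x[3:4]
tr = Stacked((elementwise(identity), elementwise(exp), elementwise(exp)), ranges)

xg = CuArray(x)                # move the input to the GPU
y = tr(xg)                     # forward transform works
gr = Zygote.jacobian(tr, xg)   # Jacobian of the forward transform works nicely

# log|det J| of the transform, and differentiating through with_logabsdet_jacobian
y, logjac = Bijectors.with_logabsdet_jacobian(tr, xg)
gr = Zygote.jacobian(x -> Bijectors.with_logabsdet_jacobian(tr, x)[1], xg)
```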
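To check whatever a GPU code path eventually returns, it may help to have a CPU reference for this particular transform. Because `exp` acts elementwise, its Jacobian is diagonal with entries `exp(x_i)`, so log|det J| is the sum of `x_i` over the `exp` blocks, while the identity block contributes 0. The sketch below hard-codes the layout of `tr` (identity on `1:1`, `exp` on `2:2` and `3:4`) in a hypothetical helper `stacked_exp_with_logjac`; it does not use Bijectors and is not how Bijectors implements `Stacked`. It sticks to slicing, broadcasting and `sum` (no scalar indexing), on the assumption that such operations carry over to `CuArray`s; this has not been verified on a GPU.

```julia
# Hypothetical reference implementation of the specific transform `tr` above:
# identity on x[1:1], exp on x[2:2] and x[3:4] (merged into one block here).
function stacked_exp_with_logjac(x::AbstractVector)
    head = x[1:1]              # identity block
    tail = x[2:end]            # both exp blocks
    y = vcat(head, exp.(tail))
    # d/dx exp(x) = exp(x), so log|det J| = sum(log(exp(x_i))) = sum(x_i)
    # over the exp blocks; the identity block contributes log(1) = 0.
    logjac = sum(tail)
    return y, logjac
end

x = [-2, 3.0, 2.1, 2.2]
y_ref, logjac_ref = stacked_exp_with_logjac(x)
@assert y_ref ≈ [-2.0, exp(3.0), exp(2.1), exp(2.2)]
@assert logjac_ref ≈ 3.0 + 2.1 + 2.2
```

The result can then be compared against `Bijectors.with_logabsdet_jacobian(tr, x)` on the CPU and against the GPU result once that path works.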

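This fails with the following error (in this REPL session the call was entered with `x` and `xg` swapped, so the closure captures `xg` and the Jacobian is taken with respect to the CPU vector `x`):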

```
julia> gr = Zygote.jacobian(x ->  Bijectors.with_logabsdet_jacobian(tr, xg)[1], x)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}}, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not a bitstype:
  .f is of type Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)} which is not isbits.
    .cx is of type Zygote.Context{false} which is not isbits.
      .cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.

Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.

Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/scratch/twutz/julia_cluster_depots/packages/GPUCompiler/Nxf8r/src/validation.jl:108
...
 [63] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/scratch/twutz/julia_cluster_depots/packages/Zygote/59YyM/src/lib/grad.jl:168
```
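The error states the constraint directly: only bitstypes can be passed to GPU kernels, and the broadcast's function object is rejected because it captures a `Zygote.Context` whose `cache` field is a `Union{Nothing, IdDict{Any, Any}}`. The plain-Julia sketch below (no GPU needed) illustrates that rule with `isbitstype`/`isbits`. `CacheContext` is a hypothetical stand-in that only mimics the field layout quoted in the error; it is not Zygote's actual type.

```julia
# Hypothetical stand-in mirroring the `.cache` field reported in the error message.
struct CacheContext
    cache::Union{Nothing, IdDict{Any, Any}}
end

# Plain numbers and tuples of them are bitstypes, so they may be kernel arguments.
@assert isbitstype(Float64)
@assert isbitstype(Tuple{Float64, Float64})

# A field that may hold an IdDict is a reference, so the struct is not a bitstype.
@assert !isbitstype(Union{Nothing, IdDict{Any, Any}})
@assert !isbitstype(CacheContext)

# Closures are isbits only if everything they capture is isbits.
f_ok = let s = 2.0
    x -> s * x            # captures only a Float64
end
f_bad = let c = CacheContext(nothing)
    x -> (c, x)           # captures a non-isbits struct, like `.cx` in the error
end

@assert isbits(f_ok)
@assert !isbits(f_bad)
```

This is why a broadcast whose function closes over a context object with a mutable cache cannot be compiled into a CUDA kernel, even when the array elements themselves (here tuples of `ForwardDiff.Dual`) are plain data.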

Environment:

  [76274a88] Bijectors v0.15.4
  [052768ef] CUDA v5.6.1
  [e88e6eb3] Zygote v0.7.2
