
Freed reference problem when combining cuTENSOR and Zygote #169

@hxjz233

Description

I have been porting my TensorOperations + Zygote code from the CPU to the GPU and have run into a new problem: the program fails with either a freed reference error or a non-bitstype exception. In most of the cases demonstrated here, `@tensor` passes while `@cutensor` throws an error. The example below reproduces the problem. I also explain briefly why I use an unusual scalar function, although that is not essential to the issue; feel free to comment on those scalar functions if their behavior intrigues you.

Example code
```julia
using LinearAlgebra, TensorOperations
using ChainRulesCore, Zygote
using CUDA, cuTENSOR

function free_ref_or_nonbits(A; use_complex=false, normalize=false)
    if use_complex
        A = ComplexF64.(A)
    end
    B = ones(Float64, 2, 2, 2)
    C = ones(Float64, 2, 2, 2)
    if isa(A, Array)
        @tensor D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5]
    else
        @cutensor D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5]
    end
    println(D)
    if normalize
        normcoef = maximum(abs.(D))
        D = D / normcoef
    end
    println(maximum(abs.(D)))
    return maximum(abs.(D))
end

function scalar_func(A; use_sum=0)   # a possibly related case
    B = ones(Float64, 2, 2)
    if isa(A, Array)
        @tensor D[1,3] := B[1,2] * A[2,3]           # passes
    else
        @cutensor D[1,3] := B[1,2] * A[2,3]         # passes
    end
    println(D)
    if use_sum == 0
        println(maximum(abs.(D)))
        return maximum(abs.(D))
    elseif use_sum == 1
        println(sum(D))
        return sum(D)
    else
        println(reduce(+, D))
        return reduce(+, D)
    end
end

function AD()
    ######################
    ##  The two flags of `free_ref_or_nonbits()` demonstrate different exceptions,
    ##  and `use_sum` in `scalar_func()` explains why I use the cumbersome maximum(abs.(D)) to return a scalar.
    ##
    ##  free_ref_or_nonbits() (scalar returned via maximum(abs.(D))):
    ##      if !(use_complex && normalize): @tensor passes, @cutensor throws a freed reference error
    ##      if use_complex && normalize:    @tensor passes, @cutensor throws
    ##          KernelError: passing and using non-bitstype argument
    ##          at the line normcoef = maximum(abs.(D))
    ##
    ##  scalar_func():
    ##      use_sum=0 (maximum(abs.(D))): @tensor and @cutensor both pass
    ##      use_sum=1 (sum(D)):           @cutensor passes, @tensor throws
    ##          MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    ##      use_sum=2 (reduce(+, D)):     @tensor passes, @cutensor throws
    ##          Zygote: try/catch is not supported, at the line println(reduce(+, D))
    ######################
    f(x) = free_ref_or_nonbits(x; use_complex=true, normalize=true)   # change the flags, or swap in scalar_func, to test the other cases
    g(x) = gradient(f, x)[1]
    initval = Matrix{Float64}([1 0; 0 0])
    println("gradient: $(g(initval))")              # CPU path: @tensor
    initval = CuArray(Matrix{Float64}([1 0; 0 0]))
    println("gradient: $(g(initval))")              # GPU path: @cutensor
    return nothing
end

AD()
```
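To check what the GPU gradient should be, a CPU-only reference for the same contraction may be useful. The sketch below assumes only LinearAlgebra and Zygote; `contract_reference` is a hypothetical helper of mine, not a TensorOperations function. It evaluates D[5,1] = Σ B[1,2,3] A[4,2] C[3,4,5] with `permutedims`, `reshape` and matrix products, so Zygote can differentiate it without `@tensor` or `@cutensor`.

```julia
using LinearAlgebra, Zygote

# Hypothetical reference helper (not part of TensorOperations):
# computes D[5,1] = sum over indices 2,3,4 of B[1,2,3] * A[4,2] * C[3,4,5]
# using only permutedims, reshape and matrix products.
function contract_reference(A, B, C)
    n1, n2, n3 = size(B)
    n4 = size(A, 1)
    # Bp[(1,3), 2]: move index 2 last, then fuse indices 1 and 3
    Bp = reshape(permutedims(B, (1, 3, 2)), n1 * n3, n2)
    # E[(1,3), 4] = sum over 2 of B[1,2,3] * A[4,2]
    E = Bp * transpose(A)
    # Regroup E as E[1, (3,4)] and C as C[(3,4), 5]
    Em = reshape(E, n1, n3 * n4)
    Cm = reshape(C, n3 * n4, size(C, 3))
    # (Em * Cm) is indexed [1,5]; transpose it to get D[5,1]
    return transpose(Em * Cm)
end

B = ones(Float64, 2, 2, 2)
C = ones(Float64, 2, 2, 2)
A = [1.0 0.0; 0.0 0.0]

D = contract_reference(A, B, C)
println(D)

# Same scalar as free_ref_or_nonbits with normalize=false
g = gradient(x -> maximum(abs.(contract_reference(x, B, C))), A)[1]
println("reference gradient: $g")
```

With the all-ones `B` and `C` used above, every entry of `D` equals 2·sum(A), so for A = [1 0; 0 0] this sketch gives a `D` filled with 2.0 and a gradient filled with 2.0, which the `@cutensor` result could be compared against.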
Version Info
Status `D:\Julia\depot\environments\v1.9\Project.toml`
⌅ [052768ef] CUDA v5.1.2
  [d360d2e6] ChainRulesCore v1.23.0
  [4db3bf67] StridedViews v0.2.2
  [6aa20fa7] TensorOperations v4.1.1
  [409d34a3] VectorInterface v0.4.5
  [cd998857] Yota v0.8.5
  [e88e6eb3] Zygote v0.6.69
⌃ [011b41b2] cuTENSOR v1.2.1

I am looking for help on what can be done to resolve the freed reference and non-bitstype issues.
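For the `sum` case in `scalar_func`, the MethodError shows that the cotangent reaching the `@tensor` pullback is a lazy `FillArrays.Fill`, for which no `StridedView` method exists. One direction that may be worth trying, untested against this issue, is to wrap the final reduction in a function whose custom `ChainRulesCore.rrule` returns a dense cotangent. The helper `dense_sum` below is hypothetical; since Zygote honors `rrule` definitions from ChainRulesCore, it should use this rule instead of differentiating through `sum`.

```julia
using ChainRulesCore, Zygote

# Hypothetical helper (not from TensorOperations or Zygote): same value as sum(x),
# but its custom pullback returns a dense array of the same kind as x
dense_sum(x::AbstractArray) = sum(x)

function ChainRulesCore.rrule(::typeof(dense_sum), x::AbstractArray)
    y = sum(x)
    function dense_sum_pullback(ȳ)
        # Fill a freshly allocated dense buffer instead of returning a lazy Fill
        x̄ = fill!(similar(x), unthunk(ȳ))
        return NoTangent(), x̄
    end
    return y, dense_sum_pullback
end

A = [1.0 0.0; 0.0 0.0]
g = gradient(dense_sum, A)[1]
println(typeof(g))
```

In `scalar_func`, `return sum(D)` would become `return dense_sum(D)`. Because `similar(x)` preserves the array type, the pullback would also produce a `CuArray` for GPU input, though I have not checked whether that avoids the `@cutensor` errors.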
