Closed
Description
I went a bit further and tried a GPU version of my TensorOperations + Zygote code, which works on CPU, and I am now running into a new problem: the program throws either a freed-reference error or a non-bitstype exception. In most of the cases demonstrated here, @tensor passes and @cutensor raises an error. The example below reproduces the problem. I also explain why I use a somewhat unusual scalar function, although that is not essential to the issue. Feel free to comment on those scalar functions if their behavior interests you.
Example code
using LinearAlgebra, TensorOperations
using ChainRulesCore, Zygote
using CUDA, cuTENSOR
function free_ref_or_nonbits(A; use_complex=false, normalize=false)
    if use_complex
        A = ComplexF64.(A)
    end
    B = ones(Float64, 2, 2, 2)
    C = ones(Float64, 2, 2, 2)
    if isa(A, Array)
        @tensor D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5]
    else
        @cutensor D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5]
    end
    println(D)
    if normalize
        normcoef = maximum(abs.(D))
        D = D / normcoef
    end
    println(maximum(abs.(D)))
    return maximum(abs.(D))
end
function scalar_func(A; use_sum=0) # a (possibly) similar case
    B = ones(Float64, 2, 2)
    if isa(A, Array)
        @tensor D[1,3] := B[1,2] * A[2,3] # passes
    else
        @cutensor D[1,3] := B[1,2] * A[2,3] # passes
    end
    println(D)
    if use_sum == 0
        println(maximum(abs.(D)))
        return maximum(abs.(D))
    elseif use_sum == 1
        println(sum(D))
        return sum(D)
    else
        println(reduce(+, D))
        return reduce(+, D)
    end
end
function AD()
    ######################
    ## Two flags of `free_ref_or_nonbits()` demonstrate different exception behaviors,
    ## and `use_sum` in `scalar_func()` explains why I use the cumbersome `maximum(abs.(D))` to return a scalar.
    ##
    ## With maximum(abs.(D)) returning the target scalar (as in free_ref_or_nonbits()):
    ##   if !(use_complex && normalize), @tensor passes and @cutensor throws a freed-reference error
    ##   if use_complex && normalize, @tensor passes and @cutensor throws
    ##     "KernelError: passing and using non-bitstype argument" at the line normcoef = maximum(abs.(D))
    ##
    ## In scalar_func():
    ##   use_sum=0 (maximum(abs.(D))): @tensor and @cutensor both pass
    ##   use_sum=1 (sum(D)): @tensor throws MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    ##     and @cutensor passes
    ##   use_sum=2 (reduce(+, D)): @tensor passes
    ##     and @cutensor throws "Zygote: try/catch is not supported" at the line println(reduce(+, D))
    ######################
    f(x) = free_ref_or_nonbits(x; use_complex=true, normalize=true) # test by swapping in the other functions/flags above
    g(x) = gradient(f, x)[1]
    initval = Matrix{Float64}([1 0; 0 0])
    println("gradient: $(g(initval))")
    initval = CuArray(Matrix{Float64}([1 0; 0 0]))
    println("gradient: $(g(initval))")
    return nothing
end
AD()
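Because the contractions in this example only involve small all-ones tensors, the expected values and gradients can be cross-checked without TensorOperations, Zygote, or CUDA. The sketch below is a minimal CPU-only reference under that assumption; the helpers contract_reference, scalar_reference, sum_reference and fd_gradient are hypothetical names, not part of any library. It evaluates D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5] with plain loops and estimates gradients by central finite differences. For initval = [1 0; 0 0] every entry of D is 2, so the unnormalized maximum(abs.(D)) should have gradient 2 in every entry, while with normalize=true the output is identically 1 (for any nonzero D) and the gradient should be zero.

```julia
# Hypothetical CPU-only reference helpers (not part of TensorOperations, Zygote, or CUDA).

# Plain-loop evaluation of D[5,1] := (B[1,2,3] * A[4,2]) * C[3,4,5]:
# indices 2, 3 and 4 are summed over, and D is indexed as D[i5, i1].
function contract_reference(A, B, C)
    n1, n2, n3 = size(B)
    n4 = size(A, 1)
    n5 = size(C, 3)
    T = promote_type(eltype(A), eltype(B), eltype(C))
    D = zeros(T, n5, n1)
    for i1 in 1:n1, i5 in 1:n5
        acc = zero(T)
        for i2 in 1:n2, i3 in 1:n3, i4 in 1:n4
            acc += B[i1, i2, i3] * A[i4, i2] * C[i3, i4, i5]
        end
        D[i5, i1] = acc
    end
    return D
end

# Same scalar as free_ref_or_nonbits for real input, without printing.
function scalar_reference(A; normalize=false)
    D = contract_reference(A, ones(2, 2, 2), ones(2, 2, 2))
    if normalize
        D = D / maximum(abs.(D))
    end
    return maximum(abs.(D))
end

# scalar_func with use_sum=1: D[1,3] := B[1,2] * A[2,3] is the matrix product B * A.
sum_reference(A) = sum(ones(2, 2) * A)

# Central finite-difference estimate of the gradient of a scalar function f at A.
function fd_gradient(f, A; h=1e-6)
    G = zeros(size(A))
    for k in eachindex(A)
        Ap = copy(A); Ap[k] += h
        Am = copy(A); Am[k] -= h
        G[k] = (f(Ap) - f(Am)) / (2h)
    end
    return G
end

A0 = [1.0 0.0; 0.0 0.0]
println(contract_reference(A0, ones(2, 2, 2), ones(2, 2, 2)))     # all entries equal 2
println(fd_gradient(scalar_reference, A0))                          # approximately 2 in every entry
println(fd_gradient(A -> scalar_reference(A; normalize=true), A0))  # zero: the output is always 1
println(fd_gradient(sum_reference, A0))                             # approximately 2 in every entry
```

Comparing the Zygote gradients from the Array and CuArray runs against these reference values helps separate a numerically wrong result from a GPU-specific failure.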
Version Info
Status `D:\Julia\depot\environments\v1.9\Project.toml`
⌅ [052768ef] CUDA v5.1.2
[d360d2e6] ChainRulesCore v1.23.0
[4db3bf67] StridedViews v0.2.2
[6aa20fa7] TensorOperations v4.1.1
[409d34a3] VectorInterface v0.4.5
[cd998857] Yota v0.8.5
[e88e6eb3] Zygote v0.6.69
⌃ [011b41b2] cuTENSOR v1.2.1
I am looking for help with what can be done to resolve the freed-reference and non-bitstype errors.
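For background on the non-bitstype error: CUDA.jl compiles GPU broadcasts such as abs.(D) into kernels, and every kernel argument must be an isbits value after CUDA.jl's argument conversion. A CuArray is converted to a device-side array, but a CPU Array, or a closure that captures one, remains non-isbits. The sketch below uses only Base Julia plus a hypothetical helper, make_scaler, to show how this property can be checked on the CPU; it does not by itself identify which value is non-isbits in the failing call.

```julia
# Hypothetical helper: returns a closure that captures `s`.
make_scaler(s) = x -> x * s

# The element types used in the issue are isbits...
println(isbitstype(Float64))          # true
println(isbitstype(ComplexF64))       # true
# ...but an ordinary CPU array is not, because it refers to heap-allocated memory.
println(isbitstype(Vector{Float64}))  # false

# A closure is isbits only if everything it captures is isbits.
println(isbits(make_scaler(2.0)))         # true: captures a Float64
println(isbits(make_scaler([1.0, 2.0])))  # false: captures a Vector{Float64}
```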