-
Notifications
You must be signed in to change notification settings - Fork 269
Expand file tree
/
Copy pathinterfaces.jl
More file actions
66 lines (56 loc) · 2.39 KB
/
interfaces.jl
File metadata and controls
66 lines (56 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# interfacing with other packages
## Base
# function Base.:(+)(A::CuTensor, B::CuTensor)
# α = convert(eltype(A), 1.0)
# γ = convert(eltype(B), 1.0)
# C = similar(B)
# elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY,
# γ, B.data, B.inds, CUTENSOR_OP_IDENTITY,
# C.data, C.inds, CUTENSOR_OP_ADD)
# C
# end
# function Base.:(-)(A::CuTensor, B::CuTensor)
# α = convert(eltype(A), 1.0)
# γ = convert(eltype(B), -1.0)
# C = similar(B)
# elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY,
# γ, B.data, B.inds, CUTENSOR_OP_IDENTITY,
# C.data, C.inds, CUTENSOR_OP_ADD)
# C
# end
## For now `Base.:*` is intentionally left undefined: ITensor calls `contract` directly
## and relies on UnallocatedArrays to construct C during a dry run of the contraction.
# function Base.:(*)(A::CuTensorBS, B::CuTensorBS)
# tC = promote_type(eltype(A), eltype(B))
# A_uniqs = [(idx, i) for (idx, i) in enumerate(A.inds) if !(i in B.inds)]
# B_uniqs = [(idx, i) for (idx, i) in enumerate(B.inds) if !(i in A.inds)]
# A_sizes = map(x->size(A,x[1]), A_uniqs)
# B_sizes = map(x->size(B,x[1]), B_uniqs)
# A_inds = map(x->x[2], A_uniqs)
# B_inds = map(x->x[2], B_uniqs)
# C = CuTensor(CUDA.zeros(tC, Dims(vcat(A_sizes, B_sizes))), vcat(A_inds, B_inds))
# return mul!(C, A, B)
# end
## LinearAlgebra
using LinearAlgebra
# function LinearAlgebra.axpy!(a, X::CuTensor, Y::CuTensor)
# elementwise_binary_execute!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY,
# one(eltype(Y)), Y.data, Y.inds, CUTENSOR_OP_IDENTITY,
# Y.data, Y.inds, CUTENSOR_OP_ADD)
# return Y
# end
# function LinearAlgebra.axpby!(a, X::CuTensor, b, Y::CuTensor)
# elementwise_binary_execute!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY,
# b, Y.data, Y.inds, CUTENSOR_OP_IDENTITY,
# Y.data, Y.inds, CUTENSOR_OP_ADD)
# return Y
# end
"""
    mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)

In-place contraction of the block-sparse tensors `A` and `B` into `C` via cuTENSOR's
`contract!`, with scaling factors `α` (applied to the `A`/`B` product) and `β`
(applied to the existing contents of `C`). The mode labels of each operand are taken
from its `inds` field. No unary operator is applied to any operand or to the output,
and the JIT mode is `CUTENSOR_JIT_MODE_DEFAULT`.

Return the mutated `C`.
"""
function LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS,
                            α::Number, β::Number)
    # Every operand (and the output) is passed through unchanged: no conjugation etc.
    identity_op = CUTENSOR_OP_IDENTITY
    contract!(α,
              A, A.inds, identity_op,
              B, B.inds, identity_op,
              β,
              C, C.inds, identity_op,
              identity_op;
              jit = CUTENSOR_JIT_MODE_DEFAULT)
    return C
end