Skip to content

Commit 42e9c9c

Browse files
committed
fix extension issues on julia 1.11.1
1 parent 3433bc9 commit 42e9c9c

File tree

4 files changed

+97
-74
lines changed

4 files changed

+97
-74
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "5.0.2"
4+
version = "5.1.0"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

ext/TensorOperationsBumperExt.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,11 @@ function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
1414
end
1515
end
1616

17-
function TensorOperations.blas_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
18-
allocator::Union{SlabBuffer,AllocBuffer})
17+
function TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
18+
backend, allocator::Union{SlabBuffer,AllocBuffer})
1919
@no_escape allocator begin
20-
C = Base.@invoke TensorOperations.blas_contract!(C::Any, A::Any, pA::Any,
21-
conjA::Any,
22-
B::Any, pB::Any,
23-
conjB::Any, pAB::Any, α::Any,
24-
β::Any,
25-
allocator::Any)
20+
C = Base.@invoke TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
21+
backend, allocator::Any)
2622
end
2723
return C
2824
end

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 82 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,34 @@ end
4646

4747
# The current `rrule` design makes sure that the implementation for custom types does
4848
# not need to support the backend or allocator arguments
49+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
50+
# C,
51+
# A, pA::Index2Tuple, conjA::Bool,
52+
# α::Number, β::Number,
53+
# backend, allocator)
54+
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
55+
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
56+
# end
57+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
58+
# C,
59+
# A, pA::Index2Tuple, conjA::Bool,
60+
# α::Number, β::Number,
61+
# backend)
62+
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
63+
# return val, ΔC -> (pb(ΔC)..., NoTangent())
64+
# end
65+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
66+
# C,
67+
# A, pA::Index2Tuple, conjA::Bool,
68+
# α::Number, β::Number)
69+
# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
70+
# end
4971
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
5072
C,
5173
A, pA::Index2Tuple, conjA::Bool,
5274
α::Number, β::Number,
53-
backend, allocator)
54-
val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
55-
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
56-
end
57-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
58-
C,
59-
A, pA::Index2Tuple, conjA::Bool,
60-
α::Number, β::Number,
61-
backend)
62-
val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
63-
return val, ΔC -> (pb(ΔC)..., NoTangent())
64-
end
65-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
66-
C,
67-
A, pA::Index2Tuple, conjA::Bool,
68-
α::Number, β::Number)
69-
return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
75+
ba...)
76+
return _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
7077
end
7178
function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
7279
C′ = tensoradd!(copy(C), A, pA, conjA, α, β, ba...)
@@ -98,40 +105,50 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
98105
((), ()), One(), ba...))
99106
return projectβ(_dβ)
100107
end
101-
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ
108+
dba = map(_ -> NoTangent(), ba)
109+
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
102110
end
103111

104112
return C′, pullback
105113
end
106114

115+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
116+
# C,
117+
# A, pA::Index2Tuple, conjA::Bool,
118+
# B, pB::Index2Tuple, conjB::Bool,
119+
# pAB::Index2Tuple,
120+
# α::Number, β::Number,
121+
# backend, allocator)
122+
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
123+
# (backend, allocator))
124+
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
125+
# end
126+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
127+
# C,
128+
# A, pA::Index2Tuple, conjA::Bool,
129+
# B, pB::Index2Tuple, conjB::Bool,
130+
# pAB::Index2Tuple,
131+
# α::Number, β::Number,
132+
# backend)
133+
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
134+
# return val, ΔC -> (pb(ΔC)..., NoTangent())
135+
# end
136+
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
137+
# C,
138+
# A, pA::Index2Tuple, conjA::Bool,
139+
# B, pB::Index2Tuple, conjB::Bool,
140+
# pAB::Index2Tuple,
141+
# α::Number, β::Number)
142+
# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
143+
# end
107144
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
108145
C,
109146
A, pA::Index2Tuple, conjA::Bool,
110147
B, pB::Index2Tuple, conjB::Bool,
111148
pAB::Index2Tuple,
112149
α::Number, β::Number,
113-
backend, allocator)
114-
val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
115-
(backend, allocator))
116-
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
117-
end
118-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
119-
C,
120-
A, pA::Index2Tuple, conjA::Bool,
121-
B, pB::Index2Tuple, conjB::Bool,
122-
pAB::Index2Tuple,
123-
α::Number, β::Number,
124-
backend)
125-
val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
126-
return val, ΔC -> (pb(ΔC)..., NoTangent())
127-
end
128-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
129-
C,
130-
A, pA::Index2Tuple, conjA::Bool,
131-
B, pB::Index2Tuple, conjB::Bool,
132-
pAB::Index2Tuple,
133-
α::Number, β::Number)
134-
return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
150+
ba...)
151+
return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
135152
end
136153
function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
137154
C′ = tensorcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
@@ -187,32 +204,39 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
187204
((), ()), One(), ba...))
188205
return projectβ(_dβ)
189206
end
207+
dba = map(_ -> NoTangent(), ba)
190208
return NoTangent(), dC,
191209
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
192-
NoTangent(), dα, dβ
210+
NoTangent(), dα, dβ, dba...
193211
end
194212

195213
return C′, pullback
196214
end
197215

216+
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
217+
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
218+
# α::Number, β::Number,
219+
# backend, allocator)
220+
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
221+
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
222+
# end
223+
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
224+
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
225+
# α::Number, β::Number,
226+
# backend)
227+
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
228+
# return val, ΔC -> (pb(ΔC)..., NoTangent())
229+
# end
230+
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
231+
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
232+
# α::Number, β::Number)
233+
# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
234+
# end
198235
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
199236
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
200237
α::Number, β::Number,
201-
backend, allocator)
202-
val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
203-
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
204-
end
205-
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
206-
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
207-
α::Number, β::Number,
208-
backend)
209-
val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
210-
return val, ΔC -> (pb(ΔC)..., NoTangent())
211-
end
212-
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
213-
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
214-
α::Number, β::Number)
215-
return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
238+
ba...)
239+
return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
216240
end
217241
function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
218242
C′ = tensortrace!(copy(C), A, p, q, conjA, α, β, ba...)
@@ -253,7 +277,8 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
253277
((), ()), One(), ba...))
254278
return projectβ(_dβ)
255279
end
256-
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ
280+
dba = map(_ -> NoTangent(), ba)
281+
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
257282
end
258283

259284
return C′, pullback

ext/TensorOperationscuTENSORExt.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@ using CUDA.Adapt: adapt
2929
using Strided
3030
using TupleTools: TupleTools as TT
3131

32-
const StridedViewsCUDAExt = @static if isdefined(Base, :get_extension)
33-
Base.get_extension(Strided.StridedViews, :StridedViewsCUDAExt)
34-
else
35-
Strided.StridedViews.StridedViewsCUDAExt
36-
end
37-
isnothing(StridedViewsCUDAExt) && error("StridedViewsCUDAExt not found")
32+
# Disallowed paradigm from Julia 1.11.1 onwards:
33+
# const StridedViewsCUDAExt = @static if isdefined(Base, :get_extension)
34+
# Base.get_extension(Strided.StridedViews, :StridedViewsCUDAExt)
35+
# else
36+
# Strided.StridedViews.StridedViewsCUDAExt
37+
# end
38+
# isnothing(StridedViewsCUDAExt) && error("StridedViewsCUDAExt not found")
39+
40+
# Literal copy of the StridedViewsCUDAExt module
41+
const CuStridedView{T,N,A<:CuArray{T}} = StridedView{T,N,A}
3842

3943
#-------------------------------------------------------------------------------------------
4044
# @cutensor macro
@@ -53,8 +57,6 @@ end
5357
#-------------------------------------------------------------------------------------------
5458
# Backend selection and passing
5559
#-------------------------------------------------------------------------------------------
56-
const CuStridedView = StridedViewsCUDAExt.CuStridedView
57-
5860
# A Base wrapper over `CuArray` will first pass via the `select_backend` methods for
5961
# `AbstractArray` and be converted into a `StridedView` if it satisfies `isstrided`. Hence,
6062
# we only need to capture `CuStridedView` here.

0 commit comments

Comments
 (0)