Skip to content

Commit 16a5ba0

Browse files
format using runic
1 parent cd439fb commit 16a5ba0

39 files changed

+1606
-1027
lines changed

docs/make.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
using Documenter
22
using TensorOperations
33

4-
makedocs(; modules=[TensorOperations],
5-
sitename="TensorOperations.jl",
6-
authors="Jutho Haegeman",
7-
format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"),
8-
pages=["Home" => "index.md",
9-
"Manual" => ["man/indexnotation.md",
10-
"man/functions.md",
11-
"man/interface.md",
12-
"man/backends.md",
13-
"man/autodiff.md",
14-
"man/implementation.md",
15-
"man/precompilation.md"],
16-
"Index" => "index/index.md"])
4+
makedocs(;
5+
modules = [TensorOperations],
6+
sitename = "TensorOperations.jl",
7+
authors = "Jutho Haegeman",
8+
format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
9+
pages = [
10+
"Home" => "index.md",
11+
"Manual" => [
12+
"man/indexnotation.md",
13+
"man/functions.md",
14+
"man/interface.md",
15+
"man/backends.md",
16+
"man/autodiff.md",
17+
"man/implementation.md",
18+
"man/precompilation.md",
19+
],
20+
"Index" => "index/index.md",
21+
]
22+
)
1723

1824
# Documenter can also automatically deploy documentation to gh-pages.
1925
# See "Hosting Documentation" and deploydocs() in the Documenter manual
2026
# for more information.
21-
deploydocs(; repo="github.com/Jutho/TensorOperations.jl.git")
27+
deploydocs(; repo = "github.com/Jutho/TensorOperations.jl.git")

ext/TensorOperationsBumperExt.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ module TensorOperationsBumperExt
33
using TensorOperations
44
using Bumper
55

6-
function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
7-
buf::Union{SlabBuffer,AllocBuffer}) where {A<:AbstractArray,
8-
istemp}
6+
function TensorOperations.tensoralloc(
7+
::Type{A}, structure, ::Val{istemp},
8+
buf::Union{SlabBuffer, AllocBuffer}
9+
) where {
10+
A <: AbstractArray,
11+
istemp,
12+
}
913
# TODO: remove the `ndims` check if this is fixed in Bumper / StrideArraysCore
1014
if istemp && ndims(A) > 0
1115
return Bumper.alloc!(buf, eltype(A), structure...)
@@ -14,11 +18,15 @@ function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
1418
end
1519
end
1620

17-
function TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
18-
backend, allocator::Union{SlabBuffer,AllocBuffer})
21+
function TensorOperations.blas_contract!(
22+
C, A, pA, B, pB, pAB, α, β,
23+
backend, allocator::Union{SlabBuffer, AllocBuffer}
24+
)
1925
@no_escape allocator begin
20-
C = Base.@invoke TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
21-
backend, allocator::Any)
26+
C = Base.@invoke TensorOperations.blas_contract!(
27+
C, A, pA, B, pB, pAB, α, β,
28+
backend, allocator::Any
29+
)
2230
end
2331
return C
2432
end
@@ -32,8 +40,12 @@ function TensorOperations._butensor(src, ex...)
3240
newex = quote
3341
$buf_sym = $(Expr(:call, GlobalRef(Bumper, :default_buffer)))
3442
$cp_sym = $(Expr(:call, GlobalRef(Bumper, :checkpoint_save), buf_sym))
35-
$res_sym = $(Expr(:macrocall, GlobalRef(TensorOperations, Symbol("@tensor")),
36-
src, :(allocator = $buf_sym), ex...))
43+
$res_sym = $(
44+
Expr(
45+
:macrocall, GlobalRef(TensorOperations, Symbol("@tensor")),
46+
src, :(allocator = $buf_sym), ex...
47+
)
48+
)
3749
$(Expr(:call, GlobalRef(Bumper, :checkpoint_restore!), cp_sym))
3850
$res_sym
3951
end

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 106 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,17 @@ trivtuple(N) = ntuple(identity, N)
2121

2222
# Cannot free intermediate tensors when using AD
2323
# Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
24-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorfree!),
25-
allocator=DefaultAllocator())
24+
function ChainRulesCore.rrule(
25+
::typeof(TensorOperations.tensorfree!),
26+
allocator = DefaultAllocator()
27+
)
2628
tensorfree!_pullback(Δargs...) = (NoTangent(), NoTangent())
2729
return nothing, tensorfree!_pullback
2830
end
29-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoralloc), ttype, structure,
30-
istemp, allocator=DefaultAllocator())
31+
function ChainRulesCore.rrule(
32+
::typeof(TensorOperations.tensoralloc), ttype, structure,
33+
istemp, allocator = DefaultAllocator()
34+
)
3135
output = TensorOperations.tensoralloc(ttype, structure, Val(false), allocator)
3236
function tensoralloc_pullback(Δargs...)
3337
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
@@ -68,11 +72,13 @@ end
6872
# α::Number, β::Number)
6973
# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
7074
# end
71-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
72-
C,
73-
A, pA::Index2Tuple, conjA::Bool,
74-
α::Number, β::Number,
75-
ba...)
75+
function ChainRulesCore.rrule(
76+
::typeof(TensorOperations.tensoradd!),
77+
C,
78+
A, pA::Index2Tuple, conjA::Bool,
79+
α::Number, β::Number,
80+
ba...
81+
)
7682
return _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
7783
end
7884
function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
@@ -93,16 +99,24 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
9399
return projectA(_dA)
94100
end
95101
= @thunk let
96-
_dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), !conjA,
97-
ΔC, (trivtuple(numind(pA)), ()), false,
98-
((), ()), One(), ba...))
102+
_dα = tensorscalar(
103+
tensorcontract(
104+
A, ((), linearize(pA)), !conjA,
105+
ΔC, (trivtuple(numind(pA)), ()), false,
106+
((), ()), One(), ba...
107+
)
108+
)
99109
return projectα(_dα)
100110
end
101111
= @thunk let
102112
# TODO: consider using `inner`
103-
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), true,
104-
ΔC, (trivtuple(numind(pA)), ()), false,
105-
((), ()), One(), ba...))
113+
_dβ = tensorscalar(
114+
tensorcontract(
115+
C, ((), trivtuple(numind(pA))), true,
116+
ΔC, (trivtuple(numind(pA)), ()), false,
117+
((), ()), One(), ba...
118+
)
119+
)
106120
return projectβ(_dβ)
107121
end
108122
dba = map(_ -> NoTangent(), ba)
@@ -141,13 +155,15 @@ end
141155
# α::Number, β::Number)
142156
# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
143157
# end
144-
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
145-
C,
146-
A, pA::Index2Tuple, conjA::Bool,
147-
B, pB::Index2Tuple, conjB::Bool,
148-
pAB::Index2Tuple,
149-
α::Number, β::Number,
150-
ba...)
158+
function ChainRulesCore.rrule(
159+
::typeof(TensorOperations.tensorcontract!),
160+
C,
161+
A, pA::Index2Tuple, conjA::Bool,
162+
B, pB::Index2Tuple, conjB::Bool,
163+
pAB::Index2Tuple,
164+
α::Number, β::Number,
165+
ba...
166+
)
151167
return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
152168
end
153169
function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
@@ -162,52 +178,66 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
162178
function pullback(ΔC′)
163179
ΔC = unthunk(ΔC′)
164180
ipAB = invperm(linearize(pAB))
165-
pΔC = (TupleTools.getindices(ipAB, trivtuple(numout(pA))),
166-
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))))
181+
pΔC = (
182+
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
183+
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
184+
)
167185
dC = @thunk projectC(scale(ΔC, conj(β)))
168186
dA = @thunk let
169187
ipA = (invperm(linearize(pA)), ())
170188
conjΔC = conjA
171189
conjB′ = conjA ? conjB : !conjB
172190
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
173-
_dA = tensorcontract!(_dA,
174-
ΔC, pΔC, conjΔC,
175-
B, reverse(pB), conjB′,
176-
ipA,
177-
conjA ? α : conj(α), Zero(), ba...)
191+
_dA = tensorcontract!(
192+
_dA,
193+
ΔC, pΔC, conjΔC,
194+
B, reverse(pB), conjB′,
195+
ipA,
196+
conjA ? α : conj(α), Zero(), ba...
197+
)
178198
return projectA(_dA)
179199
end
180200
dB = @thunk let
181201
ipB = (invperm(linearize(pB)), ())
182202
conjΔC = conjB
183203
conjA′ = conjB ? conjA : !conjA
184204
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
185-
_dB = tensorcontract!(_dB,
186-
A, reverse(pA), conjA′,
187-
ΔC, pΔC, conjΔC,
188-
ipB,
189-
conjB ? α : conj(α), Zero(), ba...)
205+
_dB = tensorcontract!(
206+
_dB,
207+
A, reverse(pA), conjA′,
208+
ΔC, pΔC, conjΔC,
209+
ipB,
210+
conjB ? α : conj(α), Zero(), ba...
211+
)
190212
return projectB(_dB)
191213
end
192214
= @thunk let
193215
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
194216
# TODO: consider using `inner`
195-
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), true,
196-
ΔC, (trivtuple(numind(pAB)), ()), false,
197-
((), ()), One(), ba...))
217+
_dα = tensorscalar(
218+
tensorcontract(
219+
C_αβ, ((), trivtuple(numind(pAB))), true,
220+
ΔC, (trivtuple(numind(pAB)), ()), false,
221+
((), ()), One(), ba...
222+
)
223+
)
198224
return projectα(_dα)
199225
end
200226
= @thunk let
201227
# TODO: consider using `inner`
202-
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), true,
203-
ΔC, (trivtuple(numind(pAB)), ()), false,
204-
((), ()), One(), ba...))
228+
_dβ = tensorscalar(
229+
tensorcontract(
230+
C, ((), trivtuple(numind(pAB))), true,
231+
ΔC, (trivtuple(numind(pAB)), ()), false,
232+
((), ()), One(), ba...
233+
)
234+
)
205235
return projectβ(_dβ)
206236
end
207237
dba = map(_ -> NoTangent(), ba)
208238
return NoTangent(), dC,
209-
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
210-
NoTangent(), dα, dβ, dba...
239+
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
240+
NoTangent(), dα, dβ, dba...
211241
end
212242

213243
return C′, pullback
@@ -232,10 +262,12 @@ end
232262
# α::Number, β::Number)
233263
# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
234264
# end
235-
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
236-
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
237-
α::Number, β::Number,
238-
ba...)
265+
function ChainRulesCore.rrule(
266+
::typeof(tensortrace!), C,
267+
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
268+
α::Number, β::Number,
269+
ba...
270+
)
239271
return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
240272
end
241273
function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
@@ -252,29 +284,43 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
252284
dA = @thunk let
253285
ip = invperm((linearize(p)..., q[1]..., q[2]...))
254286
Es = map(q[1], q[2]) do i1, i2
255-
return one(TensorOperations.tensoralloc_add(scalartype(A), A,
256-
((i1,), (i2,)), conjA))
287+
return one(
288+
TensorOperations.tensoralloc_add(
289+
scalartype(A), A,
290+
((i1,), (i2,)), conjA
291+
)
292+
)
257293
end
258294
E = _kron(Es, ba)
259295
_dA = zerovector(A, VectorInterface.promote_scale(ΔC, α))
260-
_dA = tensorproduct!(_dA, ΔC, (trivtuple(numind(p)), ()), conjA,
261-
E, ((), trivtuple(numind(q))), conjA,
262-
(ip, ()),
263-
conjA ? α : conj(α), Zero(), ba...)
296+
_dA = tensorproduct!(
297+
_dA, ΔC, (trivtuple(numind(p)), ()), conjA,
298+
E, ((), trivtuple(numind(q))), conjA,
299+
(ip, ()),
300+
conjA ? α : conj(α), Zero(), ba...
301+
)
264302
return projectA(_dA)
265303
end
266304
= @thunk let
267305
C_αβ = tensortrace(A, p, q, false, One(), ba...)
268-
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(p))),
269-
!conjA,
270-
ΔC, (trivtuple(numind(p)), ()), false,
271-
((), ()), One(), ba...))
306+
_dα = tensorscalar(
307+
tensorcontract(
308+
C_αβ, ((), trivtuple(numind(p))),
309+
!conjA,
310+
ΔC, (trivtuple(numind(p)), ()), false,
311+
((), ()), One(), ba...
312+
)
313+
)
272314
return projectα(_dα)
273315
end
274316
= @thunk let
275-
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), true,
276-
ΔC, (trivtuple(numind(p)), ()), false,
277-
((), ()), One(), ba...))
317+
_dβ = tensorscalar(
318+
tensorcontract(
319+
C, ((), trivtuple(numind(p))), true,
320+
ΔC, (trivtuple(numind(p)), ()), false,
321+
((), ()), One(), ba...
322+
)
323+
)
278324
return projectβ(_dβ)
279325
end
280326
dba = map(_ -> NoTangent(), ba)
@@ -285,7 +331,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
285331
end
286332

287333
_kron(Es::NTuple{1}, ba) = Es[1]
288-
function _kron(Es::NTuple{N,Any}, ba) where {N}
334+
function _kron(Es::NTuple{N, Any}, ba) where {N}
289335
E1 = Es[1]
290336
E2 = _kron(Base.tail(Es), ba)
291337
p2 = ((), trivtuple(2 * N - 2))

0 commit comments

Comments
 (0)