Skip to content

Commit e5651a0

Browse files
authored
update strided implementation (#191)
* update strided implementation * move blascontract * fix forgotten conj argument * add lts to ci
1 parent 427403e commit e5651a0

File tree

4 files changed

+169
-164
lines changed

4 files changed

+169
-164
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
matrix:
2323
version:
2424
- '1.8' # lowest supported version
25+
- '1.10' # julia lts
2526
- '1' # automatically expands to the latest stable 1.x release of Julia
2627
os:
2728
- ubuntu-latest

src/TensorOperations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ include("implementation/functions.jl")
5959
include("implementation/ncon.jl")
6060
include("implementation/abstractarray.jl")
6161
include("implementation/strided.jl")
62+
include("implementation/blascontract.jl")
6263
include("implementation/diagonal.jl")
6364
include("implementation/base.jl")
6465
include("implementation/indices.jl")

src/implementation/blascontract.jl

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# general implementation for backends that implement tensor contractions by permuting and
2+
# reshaping the input tensors and then calling a BLAS routine to perform the contraction
3+
4+
# all of the following methods expect that basic argument checks on dimensionality and
5+
# permutation validity have already been performed
6+
function blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
7+
rpA = reverse(pA)
8+
rpB = reverse(pB)
9+
indCinoBA = let N₁ = numout(pA), N₂ = numin(pB)
10+
map(n -> ifelse(n > N₁, n - N₁, n + N₂), linearize(pAB))
11+
end
12+
tpAB = trivialpermutation(pAB)
13+
rpAB = (TupleTools.getindices(indCinoBA, tpAB[1]),
14+
TupleTools.getindices(indCinoBA, tpAB[2]))
15+
16+
if contract_memcost(C, A, pA, B, pB, pAB) <= contract_memcost(C, B, rpB, A, rpA, rpAB)
17+
return _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
18+
else
19+
return _blas_contract!(C, B, rpB, A, rpA, rpAB, α, β, backend, allocator)
20+
end
21+
end
22+
# specialised fast path for matrix matrix multiplication
23+
function blas_contract!(C::StridedView{T,2},
24+
A::StridedView{T,2}, pA::Index2Tuple{1,1},
25+
B::StridedView{T,2}, pB::Index2Tuple{1,1},
26+
pAB::Index2Tuple{1,1},
27+
α::Number, β::Number,
28+
backend, allocator) where {T}
29+
A′ = pA == ((1,), (2,)) ? A : transpose(A)
30+
B′ = pB == ((1,), (2,)) ? B : transpose(B)
31+
if pAB == ((1,), (2,))
32+
mul!(C, A′, B′, α, β)
33+
elseif pAB == ((2,), (1,))
34+
mul!(C, transpose(B′), transpose(A′), α, β)
35+
end
36+
return C
37+
end
38+
39+
# implement necessary permutations
40+
function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
41+
TC = eltype(C)
42+
43+
A_, pA, flagA = makeblascontractable(A, pA, TC, backend, allocator)
44+
B_, pB, flagB = makeblascontractable(B, pB, TC, backend, allocator)
45+
46+
ipAB = oindABinC(pAB, pA, pB)
47+
flagC = isblasdestination(C, ipAB)
48+
if flagC
49+
C_ = C
50+
_unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β)
51+
else
52+
C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
53+
_unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
54+
one(TC), zero(TC))
55+
tensoradd!(C, C_, pAB, false, α, β, backend, allocator)
56+
tensorfree!(C_.parent, allocator)
57+
end
58+
flagA || tensorfree!(A_.parent, allocator)
59+
flagB || tensorfree!(B_.parent, allocator)
60+
return C
61+
end
62+
63+
# perform the actual contraction, assuming it can be done as matrix multiplication by simply
64+
# reshaping without any further allocations
65+
function _unsafe_blas_contract!(C::StridedView{T},
66+
A::StridedView{T}, pA,
67+
B::StridedView{T}, pB,
68+
pAB, α, β) where {T<:BlasFloat}
69+
sizeA = size(A)
70+
sizeB = size(B)
71+
csizeA = TupleTools.getindices(sizeA, pA[2])
72+
csizeB = TupleTools.getindices(sizeB, pB[1])
73+
osizeA = TupleTools.getindices(sizeA, pA[1])
74+
osizeB = TupleTools.getindices(sizeB, pB[2])
75+
76+
mul!(sreshape(permutedims(C, linearize(pAB)), (prod(osizeA), prod(osizeB))),
77+
sreshape(permutedims(A, linearize(pA)), (prod(osizeA), prod(csizeA))),
78+
sreshape(permutedims(B, linearize(pB)), (prod(csizeB), prod(osizeB))),
79+
α, β)
80+
81+
return C
82+
end
83+
84+
@inline function makeblascontractable(A, pA, TC, backend, allocator)
85+
flagA = isblascontractable(A, pA) && eltype(A) == TC
86+
if !flagA
87+
A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
88+
Anew = SV(A_, size(A_), strides(A_), 0, A.op)
89+
Anew = tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
90+
pAnew = trivialpermutation(pA)
91+
else
92+
Anew = A
93+
pAnew = pA
94+
end
95+
return Anew, pAnew, flagA
96+
end
97+
98+
function isblascontractable(A::StridedView, p::Index2Tuple)
99+
eltype(A) <: LinearAlgebra.BlasFloat || return false
100+
101+
sizeA = size(A)
102+
stridesA = strides(A)
103+
sizeA1 = TupleTools.getindices(sizeA, p[1])
104+
sizeA2 = TupleTools.getindices(sizeA, p[2])
105+
stridesA1 = TupleTools.getindices(stridesA, p[1])
106+
stridesA2 = TupleTools.getindices(stridesA, p[2])
107+
108+
canfuse1, _, s1 = _canfuse(sizeA1, stridesA1)
109+
canfuse2, _, s2 = _canfuse(sizeA2, stridesA2)
110+
111+
if A.op == conj
112+
return canfuse1 && canfuse2 && s2 == 1
113+
else
114+
return canfuse1 && canfuse2 && (s1 == 1 || s2 == 1)
115+
end
116+
end
117+
118+
function isblasdestination(A::StridedView, p::Index2Tuple)
119+
(eltype(A) <: LinearAlgebra.BlasFloat && A.op == identity) || return false
120+
121+
sizeA = size(A)
122+
stridesA = strides(A)
123+
124+
sizeA1 = TupleTools.getindices(sizeA, p[1])
125+
stridesA1 = TupleTools.getindices(stridesA, p[1])
126+
canfuse1, _, s1 = _canfuse(sizeA1, stridesA1)
127+
(canfuse1 && s1 == 1) || return false
128+
129+
sizeA2 = TupleTools.getindices(sizeA, p[2])
130+
stridesA2 = TupleTools.getindices(stridesA, p[2])
131+
canfuse2, _, _ = _canfuse(sizeA2, stridesA2)
132+
return canfuse2
133+
end
134+
135+
_canfuse(::Dims{0}, ::Dims{0}) = true, 1, 1
136+
_canfuse(dims::Dims{1}, strides::Dims{1}) = true, dims[1], strides[1]
137+
@inline function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
138+
if dims[1] == 0
139+
return true, 0, 1
140+
elseif dims[1] == 1
141+
return _canfuse(Base.tail(dims), Base.tail(strides))
142+
else
143+
b, d, s = _canfuse(Base.tail(dims), Base.tail(strides))
144+
if b && (s == dims[1] * strides[1] || d == 1)
145+
dnew = dims[1] * d
146+
return true, dnew, (dnew == 0 || dnew == 1) ? 1 : strides[1]
147+
else
148+
return false, dims[1] * d, strides[1]
149+
end
150+
end
151+
end
152+
153+
function oindABinC(pAB, pA, pB)
154+
ipAB = invperm(linearize(pAB))
155+
oindAinC = TupleTools.getindices(ipAB, trivialpermutation(pA[1]))
156+
oindBinC = TupleTools.getindices(ipAB, numout(pA) .+ trivialpermutation(pB[2]))
157+
return (oindAinC, oindBinC)
158+
end
159+
160+
function contract_memcost(C, A, pA, B, pB, pAB)
161+
ipAB = oindABinC(pAB, pA, pB)
162+
return length(A) * (!isblascontractable(A, pA) || eltype(A) !== eltype(C)) +
163+
length(B) * (!isblascontractable(B, pB) || eltype(B) !== eltype(C)) +
164+
length(C) * !isblasdestination(C, ipAB)
165+
end

src/implementation/strided.jl

Lines changed: 2 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -136,48 +136,14 @@ function stridedtensorcontract!(C::StridedView,
136136
B::StridedView, pB::Index2Tuple,
137137
pAB::Index2Tuple,
138138
α::Number, β::Number,
139-
::StridedBLAS, allocator=DefaultAllocator())
139+
backend::StridedBLAS, allocator=DefaultAllocator())
140140
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
141141
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
142142

143143
(Base.mightalias(C, A) || Base.mightalias(C, B)) &&
144144
throw(ArgumentError("output tensor must not be aliased with input tensor"))
145145

146-
rpA = reverse(pA)
147-
rpB = reverse(pB)
148-
indCinoBA = let N₁ = numout(pA), N₂ = numin(pB)
149-
map(n -> ifelse(n > N₁, n - N₁, n + N₂), linearize(pAB))
150-
end
151-
tpAB = trivialpermutation(pAB)
152-
rpAB = (TupleTools.getindices(indCinoBA, tpAB[1]),
153-
TupleTools.getindices(indCinoBA, tpAB[2]))
154-
if contract_memcost(C, A, pA, B, pB, pAB) <= contract_memcost(C, B, rpB, A, rpA, rpAB)
155-
return blas_contract!(C, A, pA, B, pB, pAB, α, β, allocator)
156-
else
157-
return blas_contract!(C, B, rpB, A, rpA, rpAB, α, β, allocator)
158-
end
159-
return C
160-
end
161-
162-
# reduce overhead for the case where it is just matrix multiplication
163-
function stridedtensorcontract!(C::StridedView{T,2},
164-
A::StridedView{T,2}, pA::Index2Tuple{1,1},
165-
B::StridedView{T,2}, pB::Index2Tuple{1,1},
166-
pAB::Index2Tuple{1,1}, α::Number, β::Number,
167-
::StridedBLAS) where {T}
168-
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
169-
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
170-
171-
(Base.mightalias(C, A) || Base.mightalias(C, B)) &&
172-
throw(ArgumentError("output tensor must not be aliased with input tensor"))
173-
174-
A′ = pA == ((1,), (2,)) ? A : permutedims(A, (pA[1][1], pA[2][1]))
175-
B′ = pB == ((1,), (2,)) ? B : permutedims(B, (pB[1][1], pB[2][1]))
176-
if pAB == ((1,), (2,))
177-
mul!(C, A′, B′, α, β)
178-
elseif pAB == ((2,), (1,))
179-
mul!(C, transpose(B′), transpose(A′), α, β)
180-
end
146+
blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
181147
return C
182148
end
183149

@@ -209,131 +175,3 @@ function stridedtensorcontract!(C::StridedView,
209175
Strided._mapreducedim!(op1, +, op2, tsize, (CS, AS, BS))
210176
return C
211177
end
212-
213-
#-------------------------------------------------------------------------------------------
214-
# StridedViewBLAS contraction implementation
215-
#-------------------------------------------------------------------------------------------
216-
function blas_contract!(C, A, pA, B, pB, pAB, α, β, allocator)
217-
TC = eltype(C)
218-
219-
A_, pA, flagA = makeblascontractable(A, pA, TC, allocator)
220-
B_, pB, flagB = makeblascontractable(B, pB, TC, allocator)
221-
222-
ipAB = oindABinC(pAB, pA, pB)
223-
flagC = isblasdestination(C, ipAB)
224-
if flagC
225-
C_ = C
226-
_unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β)
227-
else
228-
C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
229-
_unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
230-
one(TC), zero(TC))
231-
stridedtensoradd!(C, C_, pAB, α, β, StridedNative(), allocator)
232-
tensorfree!(C_.parent, allocator)
233-
end
234-
flagA || tensorfree!(A_.parent, allocator)
235-
flagB || tensorfree!(B_.parent, allocator)
236-
return C
237-
end
238-
239-
function _unsafe_blas_contract!(C::StridedView{T},
240-
A::StridedView{T}, pA,
241-
B::StridedView{T}, pB,
242-
pAB, α, β) where {T<:BlasFloat}
243-
sizeA = size(A)
244-
sizeB = size(B)
245-
csizeA = TupleTools.getindices(sizeA, pA[2])
246-
csizeB = TupleTools.getindices(sizeB, pB[1])
247-
osizeA = TupleTools.getindices(sizeA, pA[1])
248-
osizeB = TupleTools.getindices(sizeB, pB[2])
249-
250-
mul!(sreshape(permutedims(C, linearize(pAB)), (prod(osizeA), prod(osizeB))),
251-
sreshape(permutedims(A, linearize(pA)), (prod(osizeA), prod(csizeA))),
252-
sreshape(permutedims(B, linearize(pB)), (prod(csizeB), prod(osizeB))),
253-
α, β)
254-
255-
return C
256-
end
257-
258-
@inline function makeblascontractable(A, pA, TC, allocator)
259-
flagA = isblascontractable(A, pA) && eltype(A) == TC
260-
if !flagA
261-
A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
262-
Anew = SV(A_, size(A_), strides(A_), 0, A.op)
263-
Anew = stridedtensoradd!(Anew, A, pA, One(), Zero(), StridedNative(), allocator)
264-
pAnew = trivialpermutation(pA)
265-
else
266-
Anew = A
267-
pAnew = pA
268-
end
269-
return Anew, pAnew, flagA
270-
end
271-
272-
function isblascontractable(A::StridedView, p::Index2Tuple)
273-
eltype(A) <: LinearAlgebra.BlasFloat || return false
274-
275-
sizeA = size(A)
276-
stridesA = strides(A)
277-
sizeA1 = TupleTools.getindices(sizeA, p[1])
278-
sizeA2 = TupleTools.getindices(sizeA, p[2])
279-
stridesA1 = TupleTools.getindices(stridesA, p[1])
280-
stridesA2 = TupleTools.getindices(stridesA, p[2])
281-
282-
canfuse1, _, s1 = _canfuse(sizeA1, stridesA1)
283-
canfuse2, _, s2 = _canfuse(sizeA2, stridesA2)
284-
285-
if A.op == conj
286-
return canfuse1 && canfuse2 && s2 == 1
287-
else
288-
return canfuse1 && canfuse2 && (s1 == 1 || s2 == 1)
289-
end
290-
end
291-
292-
function isblasdestination(A::StridedView, p::Index2Tuple)
293-
(eltype(A) <: LinearAlgebra.BlasFloat && A.op == identity) || return false
294-
295-
sizeA = size(A)
296-
stridesA = strides(A)
297-
298-
sizeA1 = TupleTools.getindices(sizeA, p[1])
299-
stridesA1 = TupleTools.getindices(stridesA, p[1])
300-
canfuse1, _, s1 = _canfuse(sizeA1, stridesA1)
301-
(canfuse1 && s1 == 1) || return false
302-
303-
sizeA2 = TupleTools.getindices(sizeA, p[2])
304-
stridesA2 = TupleTools.getindices(stridesA, p[2])
305-
canfuse2, _, _ = _canfuse(sizeA2, stridesA2)
306-
return canfuse2
307-
end
308-
309-
_canfuse(::Dims{0}, ::Dims{0}) = true, 1, 1
310-
_canfuse(dims::Dims{1}, strides::Dims{1}) = true, dims[1], strides[1]
311-
@inline function _canfuse(dims::Dims{N}, strides::Dims{N}) where {N}
312-
if dims[1] == 0
313-
return true, 0, 1
314-
elseif dims[1] == 1
315-
return _canfuse(Base.tail(dims), Base.tail(strides))
316-
else
317-
b, d, s = _canfuse(Base.tail(dims), Base.tail(strides))
318-
if b && (s == dims[1] * strides[1] || d == 1)
319-
dnew = dims[1] * d
320-
return true, dnew, (dnew == 0 || dnew == 1) ? 1 : strides[1]
321-
else
322-
return false, dims[1] * d, strides[1]
323-
end
324-
end
325-
end
326-
327-
function oindABinC(pAB, pA, pB)
328-
ipAB = invperm(linearize(pAB))
329-
oindAinC = TupleTools.getindices(ipAB, trivialpermutation(pA[1]))
330-
oindBinC = TupleTools.getindices(ipAB, numout(pA) .+ trivialpermutation(pB[2]))
331-
return (oindAinC, oindBinC)
332-
end
333-
334-
function contract_memcost(C, A, pA, B, pB, pAB)
335-
ipAB = oindABinC(pAB, pA, pB)
336-
return length(A) * (!isblascontractable(A, pA) || eltype(A) !== eltype(C)) +
337-
length(B) * (!isblascontractable(B, pB) || eltype(B) !== eltype(C)) +
338-
length(C) * !isblasdestination(C, ipAB)
339-
end

0 commit comments

Comments
 (0)