Skip to content

Commit d71968f

Browse files
committed
Rework Strided wrapping
1 parent 2e30404 commit d71968f

File tree

4 files changed

+86
-29
lines changed

4 files changed

+86
-29
lines changed

ext/TensorOperationsBumperExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ module TensorOperationsBumperExt
33
using TensorOperations
44
using Bumper
55

6+
# Hack to normalize StridedView type to avoid too many specializations
7+
# This is allowed because Bumper ensures that the pointer won't be garbage-collected
8+
# and we never return `parent(SV)` anyway.
9+
# `Memory` only exists on Julia ≥ 1.11. On older versions this method is simply not
# defined, so the call falls back to whatever other `wrap_stridedview` method applies
# instead of throwing an `UndefVarError` at call time.
@static if isdefined(Core, :Memory)
    function TensorOperations.wrap_stridedview(A::Bumper.UnsafeArray)
        # Re-wrap the raw buffer of `A` as a `Memory`, so the resulting `StridedView` has
        # a `Memory` parent just like the one built for ordinary `Array`s.
        mem_A = Base.unsafe_wrap(Memory{eltype(A)}, pointer(A), length(A))
        return TensorOperations.StridedView(mem_A, size(A), strides(A), 0, identity)
    end
end
13+
614
function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
715
buf::Union{SlabBuffer,AllocBuffer}) where {A<:AbstractArray,
816
istemp}

src/implementation/blascontract.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,18 @@ function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
4747
flagC = isblasdestination(C, ipAB)
4848
if flagC
4949
C_ = C
50-
_unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β)
50+
_unsafe_blas_contract!(wrap_stridedview(C_),
51+
wrap_stridedview(A_), pA,
52+
wrap_stridedview(B_), pB,
53+
ipAB, α, β)
5154
else
52-
C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
53-
_unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
54-
one(TC), zero(TC))
55+
C_ = tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
56+
_unsafe_blas_contract!(wrap_stridedview(C_),
57+
wrap_stridedview(A_), pA,
58+
wrap_stridedview(B_), pB,
59+
trivialpermutation(ipAB), one(TC), zero(TC))
5560
tensoradd!(C, C_, pAB, false, α, β, backend, allocator)
56-
tensorfree!(C_.parent, allocator)
61+
tensorfree!(C_, allocator)
5762
end
5863
flagA || tensorfree!(A_.parent, allocator)
5964
flagB || tensorfree!(B_.parent, allocator)
@@ -85,8 +90,7 @@ end
8590
flagA = isblascontractable(A, pA) && eltype(A) == TC
8691
if !flagA
8792
A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
88-
Anew = SV(A_, size(A_), strides(A_), 0, A.op)
89-
Anew = tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
93+
Anew = tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator)
9094
pAnew = trivialpermutation(pA)
9195
else
9296
Anew = A

src/implementation/diagonal.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,22 @@ function tensorcontract!(C::AbstractArray,
1111
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
1212

1313
if conjA && conjB
14-
_diagtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B.diag)), pB, pAB, α, β)
14+
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
15+
conj(wrap_stridedview(B.diag)), pB,
16+
pAB, α, β)
1517
elseif conjA
16-
_diagtensorcontract!(SV(C), conj(SV(A)), pA, SV(B.diag), pB, pAB, α, β)
18+
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
19+
wrap_stridedview(B.diag),
20+
pB, pAB, α,
21+
β)
1722
elseif conjB
18-
_diagtensorcontract!(SV(C), SV(A), pA, conj(SV(B.diag)), pB, pAB, α, β)
23+
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
24+
conj(wrap_stridedview(B.diag)),
25+
pB, pAB, α,
26+
β)
1927
else
20-
_diagtensorcontract!(SV(C), SV(A), pA, SV(B.diag), pB, pAB, α, β)
28+
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
29+
wrap_stridedview(B.diag), pB, pAB, α, β)
2130
end
2231
return C
2332
end
@@ -41,13 +50,17 @@ function tensorcontract!(C::AbstractArray,
4150
TupleTools.getindices(indCinoBA, tpAB[2]))
4251

4352
if conjA && conjB
44-
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
53+
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB,
54+
conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β)
4555
elseif conjA
46-
_diagtensorcontract!(SV(C), SV(B), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
56+
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB,
57+
conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β)
4758
elseif conjB
48-
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, SV(A.diag), rpA, rpAB, α, β)
59+
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB,
60+
wrap_stridedview(A.diag), rpA, rpAB, α, β)
4961
else
50-
_diagtensorcontract!(SV(C), SV(B), rpB, SV(A.diag), rpA, rpAB, α, β)
62+
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB,
63+
wrap_stridedview(A.diag), rpA, rpAB, α, β)
5164
end
5265
return C
5366
end
@@ -62,13 +75,17 @@ function tensorcontract!(C::AbstractArray,
6275
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
6376

6477
if conjA && conjB
65-
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, conj(SV(B.diag)), pB, pAB, α, β)
78+
_diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA,
79+
conj(wrap_stridedview(B.diag)), pB, pAB, α, β)
6680
elseif conjA
67-
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, SV(B.diag), pB, pAB, α, β)
81+
_diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA,
82+
wrap_stridedview(B.diag), pB, pAB, α, β)
6883
elseif conjB
69-
_diagdiagcontract!(SV(C), SV(A.diag), pA, conj(SV(B.diag)), pB, pAB, α, β)
84+
_diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA,
85+
conj(wrap_stridedview(B.diag)), pB, pAB, α, β)
7086
else
71-
_diagdiagcontract!(SV(C), SV(A.diag), pA, SV(B.diag), pB, pAB, α, β)
87+
_diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA,
88+
wrap_stridedview(B.diag), pB, pAB, α, β)
7289
end
7390
return C
7491
end

src/implementation/strided.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,38 @@ end
3838
#-------------------------------------------------------------------------------------------
3939
# Force strided implementation on AbstractArray instances with Strided backend
4040
#-------------------------------------------------------------------------------------------
41-
const SV = StridedView
41+
42+
# we normalize the parent types here to avoid too many specializations
43+
# this is allowed because we never return `parent(SV)`, so we can safely wrap anything
44+
# that represents the same data
45+
"""
    wrap_stridedview(A::AbstractArray)

Wrap any compatible array `A` into a `StridedView` for use by the implementation.

Where possible, the parent type of the result is normalized to limit the number of
specializations that have to be compiled. This is allowed because `parent(SV)` is never
returned, so anything that represents the same data can safely be wrapped.
"""
# Generic fallback: let the `StridedView` constructor derive parent, strides and offset
# consistently. Pairing `reshape(A, length(A))` with `strides(A)` is not valid in general:
# for a non-contiguous array (e.g. a `view`) the strides refer to the underlying memory and
# index past the `length(A)` elements of the reshaped parent.
wrap_stridedview(A::AbstractArray) = StridedView(A)
wrap_stridedview(A::StridedView) = A
@static if isdefined(Core, :Memory)
    # For Arrays: we simply use the memory directly
    # NOTE(review): the zero offset assumes `A.ref` points at the first element of
    # `A.ref.mem` — confirm this holds for vectors that were resized at the front.
    # TODO: can we also do this for views?
    wrap_stridedview(A::Array) = StridedView(A.ref.mem, size(A), strides(A), 0, identity)
end
61+
4262
function tensoradd!(C::AbstractArray,
4363
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
4464
α::Number, β::Number,
4565
backend::StridedBackend, allocator=DefaultAllocator())
4666
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
4767
if conjA
48-
stridedtensoradd!(SV(C), conj(SV(A)), pA, α, β, backend, allocator)
68+
stridedtensoradd!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, α, β, backend,
69+
allocator)
4970
else
50-
stridedtensoradd!(SV(C), SV(A), pA, α, β, backend, allocator)
71+
stridedtensoradd!(wrap_stridedview(C), wrap_stridedview(A), pA, α, β, backend,
72+
allocator)
5173
end
5274
return C
5375
end
@@ -58,9 +80,11 @@ function tensortrace!(C::AbstractArray,
5880
backend::StridedBackend, allocator=DefaultAllocator())
5981
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
6082
if conjA
61-
stridedtensortrace!(SV(C), conj(SV(A)), p, q, α, β, backend, allocator)
83+
stridedtensortrace!(wrap_stridedview(C), conj(wrap_stridedview(A)), p, q, α, β,
84+
backend, allocator)
6285
else
63-
stridedtensortrace!(SV(C), SV(A), p, q, α, β, backend, allocator)
86+
stridedtensortrace!(wrap_stridedview(C), wrap_stridedview(A), p, q, α, β, backend,
87+
allocator)
6488
end
6589
return C
6690
end
@@ -73,16 +97,20 @@ function tensorcontract!(C::AbstractArray,
7397
backend::StridedBackend, allocator=DefaultAllocator())
7498
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
7599
if conjA && conjB
76-
stridedtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B)), pB, pAB, α, β,
100+
stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
101+
conj(wrap_stridedview(B)), pB, pAB, α, β,
77102
backend, allocator)
78103
elseif conjA
79-
stridedtensorcontract!(SV(C), conj(SV(A)), pA, SV(B), pB, pAB, α, β,
104+
stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
105+
wrap_stridedview(B), pB, pAB, α, β,
80106
backend, allocator)
81107
elseif conjB
82-
stridedtensorcontract!(SV(C), SV(A), pA, conj(SV(B)), pB, pAB, α, β,
108+
stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
109+
conj(wrap_stridedview(B)), pB, pAB, α, β,
83110
backend, allocator)
84111
else
85-
stridedtensorcontract!(SV(C), SV(A), pA, SV(B), pB, pAB, α, β,
112+
stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
113+
wrap_stridedview(B), pB, pAB, α, β,
86114
backend, allocator)
87115
end
88116
return C
@@ -130,7 +158,7 @@ function stridedtensortrace!(C::StridedView,
130158
newstrides = (strideA.(linearize(p))..., (strideA.(q[1]) .+ strideA.(q[2]))...)
131159
newsize = (size(C)..., tracesize...)
132160

133-
A′ = SV(A.parent, newsize, newstrides, A.offset, A.op)
161+
A′ = StridedView(A.parent, newsize, newstrides, A.offset, A.op)
134162
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), newsize, (C, A′))
135163
return C
136164
end

0 commit comments

Comments
 (0)