Skip to content

Commit e103313

Browse files
committed
revert strided wrapping
1 parent d784019 commit e103313

File tree

4 files changed

+29
-79
lines changed

4 files changed

+29
-79
lines changed

ext/TensorOperationsBumperExt.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@ using PrecompileTools
77
using Bumper
88
using Bumper: UnsafeArray
99

10-
# Hack to normalize StridedView type to avoid too many specializations
11-
# This is allowed because bumper ensures that the pointer won't be GC'd
12-
# and we never return `parent(SV)` anyways.
13-
@static if isdefined(Core, :Memory)
14-
function TensorOperations.wrap_stridedview(A::Bumper.UnsafeArray)
15-
mem_A = Base.unsafe_wrap(Memory{eltype(A)}, pointer(A), length(A))
16-
return TensorOperations.StridedView(mem_A, size(A), strides(A), 0, identity)
17-
end
18-
end
19-
2010
function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
2111
buf::Union{SlabBuffer,AllocBuffer}) where {A<:AbstractArray,
2212
istemp}

src/implementation/blascontract.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,13 @@ function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
4747
flagC = isblasdestination(C, ipAB)
4848
if flagC
4949
C_ = C
50-
_unsafe_blas_contract!(wrap_stridedview(C_),
51-
wrap_stridedview(A_), pA,
52-
wrap_stridedview(B_), pB,
53-
ipAB, α, β)
50+
_unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β)
5451
else
55-
C_ = tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
56-
_unsafe_blas_contract!(wrap_stridedview(C_),
57-
wrap_stridedview(A_), pA,
58-
wrap_stridedview(B_), pB,
59-
trivialpermutation(ipAB), one(TC), zero(TC))
52+
C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
53+
_unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
54+
one(TC), zero(TC))
6055
tensoradd!(C, C_, pAB, false, α, β, backend, allocator)
61-
tensorfree!(C_, allocator)
56+
tensorfree!(C_.parent, allocator)
6257
end
6358
flagA || tensorfree!(A_, allocator)
6459
flagB || tensorfree!(B_, allocator)
@@ -90,7 +85,8 @@ function makeblascontractable(A, pA, TC, backend, allocator)
9085
flagA = isblascontractable(A, pA) && eltype(A) == TC
9186
if !flagA
9287
A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
93-
Anew = tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator)
88+
Anew = SV(A_, size(A_), strides(A_), 0, A.op)
89+
Anew = tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
9490
pAnew = trivialpermutation(pA)
9591
else
9692
Anew = A

src/implementation/diagonal.jl

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,13 @@ function tensorcontract!(C::AbstractArray,
1111
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
1212

1313
if conjA && conjB
14-
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
15-
conj(wrap_stridedview(B.diag)), pB,
16-
pAB, α, β)
14+
_diagtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B.diag)), pB, pAB, α, β)
1715
elseif conjA
18-
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
19-
wrap_stridedview(B.diag),
20-
pB, pAB, α,
21-
β)
16+
_diagtensorcontract!(SV(C), conj(SV(A)), pA, SV(B.diag), pB, pAB, α, β)
2217
elseif conjB
23-
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
24-
conj(wrap_stridedview(B.diag)),
25-
pB, pAB, α,
26-
β)
18+
_diagtensorcontract!(SV(C), SV(A), pA, conj(SV(B.diag)), pB, pAB, α, β)
2719
else
28-
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
29-
wrap_stridedview(B.diag), pB, pAB, α, β)
20+
_diagtensorcontract!(SV(C), SV(A), pA, SV(B.diag), pB, pAB, α, β)
3021
end
3122
return C
3223
end
@@ -50,17 +41,13 @@ function tensorcontract!(C::AbstractArray,
5041
TupleTools.getindices(indCinoBA, tpAB[2]))
5142

5243
if conjA && conjB
53-
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB,
54-
conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β)
44+
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
5545
elseif conjA
56-
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB,
57-
conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β)
46+
_diagtensorcontract!(SV(C), SV(B), rpB, conj(SV(A.diag)), rpA, rpAB, α, β)
5847
elseif conjB
59-
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB,
60-
wrap_stridedview(A.diag), rpA, rpAB, α, β)
48+
_diagtensorcontract!(SV(C), conj(SV(B)), rpB, SV(A.diag), rpA, rpAB, α, β)
6149
else
62-
_diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB,
63-
wrap_stridedview(A.diag), rpA, rpAB, α, β)
50+
_diagtensorcontract!(SV(C), SV(B), rpB, SV(A.diag), rpA, rpAB, α, β)
6451
end
6552
return C
6653
end
@@ -75,17 +62,13 @@ function tensorcontract!(C::AbstractArray,
7562
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
7663

7764
if conjA && conjB
78-
_diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA,
79-
conj(wrap_stridedview(B.diag)), pB, pAB, α, β)
65+
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, conj(SV(B.diag)), pB, pAB, α, β)
8066
elseif conjA
81-
_diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA,
82-
wrap_stridedview(B.diag), pB, pAB, α, β)
67+
_diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, SV(B.diag), pB, pAB, α, β)
8368
elseif conjB
84-
_diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA,
85-
conj(wrap_stridedview(B.diag)), pB, pAB, α, β)
69+
_diagdiagcontract!(SV(C), SV(A.diag), pA, conj(SV(B.diag)), pB, pAB, α, β)
8670
else
87-
_diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA,
88-
wrap_stridedview(B.diag), pB, pAB, α, β)
71+
_diagdiagcontract!(SV(C), SV(A.diag), pA, SV(B.diag), pB, pAB, α, β)
8972
end
9073
return C
9174
end

src/implementation/strided.jl

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,16 @@ end
3838
#-------------------------------------------------------------------------------------------
3939
# Force strided implementation on AbstractArray instances with Strided backend
4040
#-------------------------------------------------------------------------------------------
41-
42-
# Wrap any compatible array into a `StridedView` for the implementation.
43-
# Additionally, we normalize the parent types to avoid having to have too many specializations.
44-
# This is allowed because we never return `parent(SV)`, so we can safely wrap anything
45-
# that represents the same data.
46-
wrap_stridedview(A::AbstractArray) = StridedView(A)
47-
@static if isdefined(Core, :Memory)
48-
# For Arrays: we simply use the memory directly
49-
# TODO: can we also do this for views?
50-
wrap_stridedview(A::Array) = StridedView(A.ref.mem, size(A), strides(A), 0, identity)
51-
end
52-
41+
const SV = StridedView
5342
function tensoradd!(C::AbstractArray,
5443
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
5544
α::Number, β::Number,
5645
backend::StridedBackend, allocator=DefaultAllocator())
5746
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
5847
if conjA
59-
stridedtensoradd!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, α, β, backend,
60-
allocator)
48+
stridedtensoradd!(SV(C), conj(SV(A)), pA, α, β, backend, allocator)
6149
else
62-
stridedtensoradd!(wrap_stridedview(C), wrap_stridedview(A), pA, α, β, backend,
63-
allocator)
50+
stridedtensoradd!(SV(C), SV(A), pA, α, β, backend, allocator)
6451
end
6552
return C
6653
end
@@ -71,11 +58,9 @@ function tensortrace!(C::AbstractArray,
7158
backend::StridedBackend, allocator=DefaultAllocator())
7259
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
7360
if conjA
74-
stridedtensortrace!(wrap_stridedview(C), conj(wrap_stridedview(A)), p, q, α, β,
75-
backend, allocator)
61+
stridedtensortrace!(SV(C), conj(SV(A)), p, q, α, β, backend, allocator)
7662
else
77-
stridedtensortrace!(wrap_stridedview(C), wrap_stridedview(A), p, q, α, β, backend,
78-
allocator)
63+
stridedtensortrace!(SV(C), SV(A), p, q, α, β, backend, allocator)
7964
end
8065
return C
8166
end
@@ -88,20 +73,16 @@ function tensorcontract!(C::AbstractArray,
8873
backend::StridedBackend, allocator=DefaultAllocator())
8974
# resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
9075
if conjA && conjB
91-
stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
92-
conj(wrap_stridedview(B)), pB, pAB, α, β,
76+
stridedtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B)), pB, pAB, α, β,
9377
backend, allocator)
9478
elseif conjA
95-
stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA,
96-
wrap_stridedview(B), pB, pAB, α, β,
79+
stridedtensorcontract!(SV(C), conj(SV(A)), pA, SV(B), pB, pAB, α, β,
9780
backend, allocator)
9881
elseif conjB
99-
stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
100-
conj(wrap_stridedview(B)), pB, pAB, α, β,
82+
stridedtensorcontract!(SV(C), SV(A), pA, conj(SV(B)), pB, pAB, α, β,
10183
backend, allocator)
10284
else
103-
stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA,
104-
wrap_stridedview(B), pB, pAB, α, β,
85+
stridedtensorcontract!(SV(C), SV(A), pA, SV(B), pB, pAB, α, β,
10586
backend, allocator)
10687
end
10788
return C
@@ -149,7 +130,7 @@ function stridedtensortrace!(C::StridedView,
149130
newstrides = (strideA.(linearize(p))..., (strideA.(q[1]) .+ strideA.(q[2]))...)
150131
newsize = (size(C)..., tracesize...)
151132

152-
A′ = StridedView(A.parent, newsize, newstrides, A.offset, A.op)
133+
A′ = SV(A.parent, newsize, newstrides, A.offset, A.op)
153134
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), newsize, (C, A′))
154135
return C
155136
end

0 commit comments

Comments
 (0)