3838# -------------------------------------------------------------------------------------------
3939# Force strided implementation on AbstractArray instances with Strided backend
4040# -------------------------------------------------------------------------------------------
41- const SV = StridedView
41+
42+ # we normalize the parent types here to avoid too many specializations
43+ # this is allowed because we never return `parent(SV)`, so we can safely wrap anything
44+ # that represents the same data
45+ """
46+ wrap_stridedview(A::AbstractArray)
47+
48+ Wrap any compatible array into a `StridedView` for the implementation.
49+ Additionally, we normalize the parent types to avoid having to have too many specializations.
50+ This is allowed because we never return `parent(SV)`, so we can safely wrap anything
51+ that represents the same data.
52+ """
53+ wrap_stridedview (A:: AbstractArray ) = StridedView (reshape (A, length (A)),
54+ size (A), strides (A), 0 , identity)
55+ wrap_stridedview (A:: StridedView ) = A
56+ @static if isdefined (Core, :Memory )
57+ # For Arrays: we simply use the memory directly
58+ # TODO : can we also do this for views?
59+ wrap_stridedview (A:: Array ) = StridedView (A. ref. mem, size (A), strides (A), 0 , identity)
60+ end
61+
4262function tensoradd! (C:: AbstractArray ,
4363 A:: AbstractArray , pA:: Index2Tuple , conjA:: Bool ,
4464 α:: Number , β:: Number ,
4565 backend:: StridedBackend , allocator= DefaultAllocator ())
4666 # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
4767 if conjA
48- stridedtensoradd! (SV (C), conj (SV (A)), pA, α, β, backend, allocator)
68+ stridedtensoradd! (wrap_stridedview (C), conj (wrap_stridedview (A)), pA, α, β, backend,
69+ allocator)
4970 else
50- stridedtensoradd! (SV (C), SV (A), pA, α, β, backend, allocator)
71+ stridedtensoradd! (wrap_stridedview (C), wrap_stridedview (A), pA, α, β, backend,
72+ allocator)
5173 end
5274 return C
5375end
@@ -58,9 +80,11 @@ function tensortrace!(C::AbstractArray,
5880 backend:: StridedBackend , allocator= DefaultAllocator ())
5981 # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
6082 if conjA
61- stridedtensortrace! (SV (C), conj (SV (A)), p, q, α, β, backend, allocator)
83+ stridedtensortrace! (wrap_stridedview (C), conj (wrap_stridedview (A)), p, q, α, β,
84+ backend, allocator)
6285 else
63- stridedtensortrace! (SV (C), SV (A), p, q, α, β, backend, allocator)
86+ stridedtensortrace! (wrap_stridedview (C), wrap_stridedview (A), p, q, α, β, backend,
87+ allocator)
6488 end
6589 return C
6690end
@@ -73,16 +97,20 @@ function tensorcontract!(C::AbstractArray,
7397 backend:: StridedBackend , allocator= DefaultAllocator ())
7498 # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on
7599 if conjA && conjB
76- stridedtensorcontract! (SV (C), conj (SV (A)), pA, conj (SV (B)), pB, pAB, α, β,
100+ stridedtensorcontract! (wrap_stridedview (C), conj (wrap_stridedview (A)), pA,
101+ conj (wrap_stridedview (B)), pB, pAB, α, β,
77102 backend, allocator)
78103 elseif conjA
79- stridedtensorcontract! (SV (C), conj (SV (A)), pA, SV (B), pB, pAB, α, β,
104+ stridedtensorcontract! (wrap_stridedview (C), conj (wrap_stridedview (A)), pA,
105+ wrap_stridedview (B), pB, pAB, α, β,
80106 backend, allocator)
81107 elseif conjB
82- stridedtensorcontract! (SV (C), SV (A), pA, conj (SV (B)), pB, pAB, α, β,
108+ stridedtensorcontract! (wrap_stridedview (C), wrap_stridedview (A), pA,
109+ conj (wrap_stridedview (B)), pB, pAB, α, β,
83110 backend, allocator)
84111 else
85- stridedtensorcontract! (SV (C), SV (A), pA, SV (B), pB, pAB, α, β,
112+ stridedtensorcontract! (wrap_stridedview (C), wrap_stridedview (A), pA,
113+ wrap_stridedview (B), pB, pAB, α, β,
86114 backend, allocator)
87115 end
88116 return C
@@ -130,7 +158,7 @@ function stridedtensortrace!(C::StridedView,
130158 newstrides = (strideA .(linearize (p))... , (strideA .(q[1 ]) .+ strideA .(q[2 ])). .. )
131159 newsize = (size (C)... , tracesize... )
132160
133- A′ = SV (A. parent, newsize, newstrides, A. offset, A. op)
161+ A′ = StridedView (A. parent, newsize, newstrides, A. offset, A. op)
134162 Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), newsize, (C, A′))
135163 return C
136164end
0 commit comments