9191# -------------------------------------------------------------------------------------------
9292# StridedView implementation
9393# -------------------------------------------------------------------------------------------
94+ struct Adder end
95+ (:: Adder )(x, y) = VectorInterface. add (x, y)
96+ struct Scaler{T}
97+ α:: T
98+ end
99+ (s:: Scaler )(x) = scale (x, s. α)
100+ (s:: Scaler )(x, y) = scale (x * y, s. α)
101+
94102function stridedtensoradd! (C:: StridedView ,
95103 A:: StridedView , pA:: Index2Tuple ,
96104 α:: Number , β:: Number ,
@@ -102,9 +110,7 @@ function stridedtensoradd!(C::StridedView,
102110 end
103111
104112 A′ = permutedims (A, linearize (pA))
105- op1 = Base. Fix2 (scale, α)
106- op2 = Base. Fix2 (scale, β)
107- Strided. _mapreducedim! (op1, + , op2, size (C), (C, A′))
113+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), size (C), (C, A′))
108114 return C
109115end
110116
@@ -125,9 +131,7 @@ function stridedtensortrace!(C::StridedView,
125131 newsize = (size (C)... , tracesize... )
126132
127133 A′ = SV (A. parent, newsize, newstrides, A. offset, A. op)
128- op1 = Base. Fix2 (scale, α)
129- op2 = Base. Fix2 (scale, β)
130- Strided. _mapreducedim! (op1, + , op2, newsize, (C, A′))
134+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), newsize, (C, A′))
131135 return C
132136end
133137
@@ -170,8 +174,6 @@ function stridedtensorcontract!(C::StridedView,
170174 (osizeA... , osizeB... , one .(csizeA)... ))
171175 tsize = (osizeA... , osizeB... , csizeA... )
172176
173- op1 = Base. Fix2 (scale, α) ∘ *
174- op2 = Base. Fix2 (scale, β)
175- Strided. _mapreducedim! (op1, + , op2, tsize, (CS, AS, BS))
177+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), tsize, (CS, AS, BS))
176178 return C
177179end
0 commit comments