Skip to content

Commit da806d3

Browse files
Merge pull request #92 from vpuri3/reshape
using `Base.ReshapedArray` is unnecessary
2 parents c69db67 + 4d413f6 commit da806d3

File tree

7 files changed

+62
-76
lines changed

7 files changed

+62
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLOperators"
22
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
33
authors = ["xtalax <[email protected]>"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/batch.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
5050
if isreal(L)
5151
true
5252
else
53-
d = _vec(L.diag)
53+
d = vec(L.diag)
5454
D = Diagonal(d)
5555
ishermitian(d)
5656
end
5757
end
58-
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(_vec(L.diag)))
58+
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
5959

6060
isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
6161
issquare(L::BatchedDiagonalOperator) = true
@@ -72,38 +72,38 @@ Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
7272
Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u
7373

7474
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat)
75-
V = _vec(v)
76-
U = _vec(u)
77-
d = _vec(L.diag)
75+
V = vec(v)
76+
U = vec(u)
77+
d = vec(L.diag)
7878
D = Diagonal(d)
7979
mul!(V, D, U)
8080

8181
v
8282
end
8383

8484
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat, α, β)
85-
V = _vec(v)
86-
U = _vec(u)
87-
d = _vec(L.diag)
85+
V = vec(v)
86+
U = vec(u)
87+
d = vec(L.diag)
8888
D = Diagonal(d)
8989
mul!(V, D, U, α, β)
9090

9191
v
9292
end
9393

9494
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat)
95-
V = _vec(v)
96-
U = _vec(u)
97-
d = _vec(L.diag)
95+
V = vec(v)
96+
U = vec(u)
97+
d = vec(L.diag)
9898
D = Diagonal(d)
9999
ldiv!(V, D, U)
100100

101101
v
102102
end
103103

104104
function LinearAlgebra.ldiv!(L::BatchedDiagonalOperator, u::AbstractVecOrMat)
105-
U = _vec(u)
106-
d = _vec(L.diag)
105+
U = vec(u)
106+
d = vec(L.diag)
107107
D = Diagonal(d)
108108
ldiv!(D, U)
109109

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function cache_operator(L::AbstractSciMLOperator, u::AbstractArray)
5757

5858
@assert s[1] == n "Dimension mismatch"
5959

60-
U = _reshape(u, (n, k))
60+
U = reshape(u, (n, k))
6161
L = cache_operator(L, U)
6262
L
6363
end

src/left.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ for (op, LType, VType) in (
9494
)
9595

9696
@eval function cache_internals(L::$LType, u::AbstractVecOrMat)
97-
@set! L.L = cache_operator(L.L, _reshape(u, size(L,1)))
97+
@set! L.L = cache_operator(L.L, reshape(u, size(L,1)))
9898
L
9999
end
100100

src/multidim.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ for op in (
1414
(size(L, 1), size(u)[2:end]...,)
1515
end
1616

17-
uu = _reshape(u, sizes[1])
17+
uu = reshape(u, sizes[1])
1818
vv = $op(L, uu)
1919

20-
_reshape(vv, sizev)
20+
reshape(vv, sizev)
2121
end
2222
end
2323

@@ -26,8 +26,8 @@ function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u:
2626

2727
sizes = _mat_sizes(L, u)
2828

29-
uu = _reshape(u, sizes[1])
30-
vv = _reshape(v, sizes[2])
29+
uu = reshape(u, sizes[1])
30+
vv = reshape(v, sizes[2])
3131

3232
mul!(vv, L, uu)
3333

@@ -39,8 +39,8 @@ function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u:
3939

4040
sizes = _mat_sizes(L, u)
4141

42-
uu = _reshape(u, sizes[1])
43-
vv = _reshape(v, sizes[2])
42+
uu = reshape(u, sizes[1])
43+
vv = reshape(v, sizes[2])
4444

4545
mul!(vv, L, uu, α, β)
4646

@@ -52,8 +52,8 @@ function LinearAlgebra.ldiv!(v::AbstractArray, L::AbstractSciMLLinearOperator, u
5252

5353
sizes = _mat_sizes(L, u)
5454

55-
uu = _reshape(u, sizes[1])
56-
vv = _reshape(v, sizes[2])
55+
uu = reshape(u, sizes[1])
56+
vv = reshape(v, sizes[2])
5757

5858
ldiv!(vv, L, uu)
5959

@@ -65,7 +65,7 @@ function LinearAlgebra.ldiv!(L::AbstractSciMLLinearOperator, u::AbstractArray)
6565

6666
sizes = _mat_sizes(L, u)
6767

68-
uu = _reshape(u, sizes[1])
68+
uu = reshape(u, sizes[1])
6969

7070
ldiv!(L, uu)
7171

src/tensor.jl

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ function Base.:*(L::TensorProductOperator, u::AbstractVecOrMat)
9696
m , n = size(L)
9797
k = size(u, 2)
9898

99-
U = _reshape(u, (ni, no*k))
99+
U = reshape(u, (ni, no*k))
100100
C = L.inner * U
101101

102102
V = outer_mul(L, u, C)
103103

104-
u isa AbstractMatrix ? _reshape(V, (m, k)) : _reshape(V, (m,))
104+
u isa AbstractMatrix ? reshape(V, (m, k)) : reshape(V, (m,))
105105
end
106106

107107
function Base.:\(L::TensorProductOperator, u::AbstractVecOrMat)
@@ -110,12 +110,12 @@ function Base.:\(L::TensorProductOperator, u::AbstractVecOrMat)
110110
m , n = size(L)
111111
k = size(u, 2)
112112

113-
U = _reshape(u, (ni, no*k))
113+
U = reshape(u, (ni, no*k))
114114
C = L.inner \ U
115115

116116
V = outer_div(L, u, C)
117117

118-
u isa AbstractMatrix ? _reshape(V, (m, k)) : _reshape(V, (m,))
118+
u isa AbstractMatrix ? reshape(V, (m, k)) : reshape(V, (m,))
119119
end
120120

121121
function cache_self(L::TensorProductOperator, u::AbstractVecOrMat)
@@ -141,7 +141,7 @@ function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat) where{D}
141141
_ , no = size(L.outer)
142142
k = size(u, 2)
143143

144-
uinner = _reshape(u, (ni, no*k))
144+
uinner = reshape(u, (ni, no*k))
145145
uouter = L.cache[2]
146146

147147
@set! L.inner = cache_operator(L.inner, uinner)
@@ -158,7 +158,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
158158
k = size(u, 2)
159159

160160
C1, C2, C3, _ = L.cache
161-
U = _reshape(u, (ni, no*k))
161+
U = reshape(u, (ni, no*k))
162162

163163
"""
164164
v .= kron(B, A) * u
@@ -183,7 +183,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
183183
k = size(u, 2)
184184

185185
C1, C2, C3, c4 = L.cache
186-
U = _reshape(u, (ni, no*k))
186+
U = reshape(u, (ni, no*k))
187187

188188
"""
189189
v .= α * kron(B, A) * u + β * v
@@ -208,7 +208,7 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
208208
k = size(u, 2)
209209

210210
C1, C2, C3, _ = L.cache
211-
U = _reshape(u, (ni, no*k))
211+
U = reshape(u, (ni, no*k))
212212

213213
"""
214214
v .= kron(B, A) ldiv u
@@ -232,7 +232,7 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
232232
no = size(L.outer, 1)
233233
k = size(u, 2)
234234

235-
U = _reshape(u, (ni, no*k))
235+
U = reshape(u, (ni, no*k))
236236

237237
"""
238238
u .= kron(B, A) ldiv u
@@ -268,12 +268,12 @@ function outer_mul(L::TensorProductOperator, u::AbstractVecOrMat, C::AbstractVec
268268
mo, no = size(L.outer)
269269
# m , n = size(L)
270270

271-
C = _reshape(C, (mi, no, k))
271+
C = reshape(C, (mi, no, k))
272272
C = permutedims(C, PERM)
273-
C = _reshape(C, (no, mi*k))
273+
C = reshape(C, (no, mi*k))
274274

275275
V = L.outer * C
276-
V = _reshape(V, (mo, mi, k))
276+
V = reshape(V, (mo, mi, k))
277277
V = permutedims(V, PERM)
278278

279279
V
@@ -297,20 +297,20 @@ function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVe
297297
k = size(u, 2)
298298

299299
if k == 1
300-
V = _reshape(v, (mi, mo))
301-
C1 = _reshape(C1, (mi, no))
300+
V = reshape(v, (mi, mo))
301+
C1 = reshape(C1, (mi, no))
302302
mul!(transpose(V), L.outer, transpose(C1))
303303
return v
304304
end
305305

306306
_, C2, C3, _ = L.cache
307307

308-
C1 = _reshape(C1, (mi, no, k))
308+
C1 = reshape(C1, (mi, no, k))
309309
permutedims!(C2, C1, PERM)
310-
C2 = _reshape(C2, (no, mi*k))
310+
C2 = reshape(C2, (no, mi*k))
311311
mul!(C3, L.outer, C2)
312-
C3 = _reshape(C3, (mo, mi, k))
313-
V = _reshape(v , (mi, mo, k))
312+
C3 = reshape(C3, (mo, mi, k))
313+
V = reshape(v , (mi, mo, k))
314314
permutedims!(V, C3, PERM)
315315

316316
v
@@ -322,7 +322,7 @@ function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVe
322322
m, _ = size(L)
323323

324324
if L.outer isa IdentityOperator
325-
c1 = _reshape(C1, (m, k))
325+
c1 = reshape(C1, (m, k))
326326
axpby!(α, c1, β, v)
327327
return v
328328
elseif L.outer isa ScaledOperator
@@ -336,20 +336,20 @@ function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVe
336336
k = size(u, 2)
337337

338338
if k == 1
339-
V = _reshape(v, (mi, mo))
340-
C1 = _reshape(C1, (mi, no))
339+
V = reshape(v, (mi, mo))
340+
C1 = reshape(C1, (mi, no))
341341
mul!(transpose(V), L.outer, transpose(C1), α, β)
342342
return v
343343
end
344344

345345
_, C2, C3, c4 = L.cache
346346

347-
C1 = _reshape(C1, (mi, no, k))
347+
C1 = reshape(C1, (mi, no, k))
348348
permutedims!(C2, C1, PERM)
349-
C2 = _reshape(C2, (no, mi*k))
349+
C2 = reshape(C2, (no, mi*k))
350350
mul!(C3, L.outer, C2)
351-
C3 = _reshape(C3, (mo, mi, k))
352-
V = _reshape(v , (mi, mo, k))
351+
C3 = reshape(C3, (mo, mi, k))
352+
V = reshape(v , (mi, mo, k))
353353
copy!(c4, v)
354354
permutedims!(V, C3, PERM)
355355
axpby!(β, c4, α, v)
@@ -373,12 +373,12 @@ function outer_div(L::TensorProductOperator, u::AbstractVecOrMat, C::AbstractVec
373373
mo, no = size(L.outer)
374374
# m , n = size(L)
375375

376-
C = _reshape(C, (mi, no, k))
376+
C = reshape(C, (mi, no, k))
377377
C = permutedims(C, PERM)
378-
C = _reshape(C, (no, mi*k))
378+
C = reshape(C, (no, mi*k))
379379

380380
V = L.outer \ C
381-
V = _reshape(V, (mo, mi, k))
381+
V = reshape(V, (mo, mi, k))
382382
V = permutedims(V, PERM)
383383

384384
V
@@ -402,20 +402,20 @@ function outer_div!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVe
402402
k = size(u, 2)
403403

404404
if k == 1
405-
V = _reshape(v, (mi, mo))
406-
C1 = _reshape(C1, (mi, no))
405+
V = reshape(v, (mi, mo))
406+
C1 = reshape(C1, (mi, no))
407407
ldiv!(transpose(V), L.outer, transpose(C1))
408408
return v
409409
end
410410

411411
_, C2, C3, _ = L.cache
412412

413-
C1 = _reshape(C1, (mi, no, k))
413+
C1 = reshape(C1, (mi, no, k))
414414
permutedims!(C2, C1, PERM)
415-
C2 = _reshape(C2, (no, mi*k))
415+
C2 = reshape(C2, (no, mi*k))
416416
ldiv!(C3, L.outer, C2)
417-
C3 = _reshape(C3, (mo, mi, k))
418-
V = _reshape(v , (mi, mo, k))
417+
C3 = reshape(C3, (mo, mi, k))
418+
V = reshape(v , (mi, mo, k))
419419
permutedims!(V, C3, PERM)
420420

421421
v
@@ -435,7 +435,7 @@ function outer_div!(L::TensorProductOperator, u::AbstractVecOrMat)
435435
# m , n = size(L)
436436
k = size(u, 2)
437437

438-
U = _reshape(u, (ni, no*k))
438+
U = reshape(u, (ni, no*k))
439439

440440
if k == 1
441441
ldiv!(L.outer, transpose(U))
@@ -444,8 +444,8 @@ function outer_div!(L::TensorProductOperator, u::AbstractVecOrMat)
444444

445445
C = first(L.cache)
446446

447-
U = _reshape(U, (ni, no, k))
448-
C = _reshape(C, (no, ni, k))
447+
U = reshape(U, (ni, no, k))
448+
C = reshape(C, (no, ni, k))
449449
permutedims!(C, U, PERM)
450450
ldiv!(L.outer, C)
451451
permutedims!(U, C, PERM)

src/utils.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,4 @@
11
#
2-
""" use Base.ReshapedArray """
3-
_reshape(a, dims::NTuple{D,Int}) where{D} = reshape(a,dims)
4-
_reshape(a::ReshapedArray, dims::NTuple{D,Int}) where{D} = _reshape(a.parent, dims)
5-
6-
function _reshape(a::AbstractArray, dims::NTuple{D,Int}) where{D}
7-
@assert prod(dims) == length(a) "cannot reshape array of size $(size(a)) to size $dims"
8-
dims == size(a) && return a
9-
ReshapedArray(a, dims, ())
10-
end
11-
12-
_vec(a) = vec(a)
13-
_vec(a::AbstractVector) = a
14-
_vec(a::AbstractArray) = _reshape(a,(length(a),))
15-
162
function _mat_sizes(L::AbstractSciMLOperator, u::AbstractArray)
173
m, n = size(L)
184
nk = length(u)

0 commit comments

Comments
 (0)