Skip to content

Commit b271ffc

Browse files
authored
Add permutedims/factorize overrides (#64)
* Add permutedims/factorize overrides * import permutedims
1 parent 492fa27 commit b271ffc

File tree

6 files changed

+59
-7
lines changed

6 files changed

+59
-7
lines changed

src/LazyArrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Base: AbstractArray, AbstractMatrix, AbstractVector,
1717
getindex, setindex!, intersect, @_inline_meta, inv,
1818
sort, sort!, issorted, sortperm, diff, cumsum, sum, in, broadcast,
1919
eltype, parent, real, imag,
20-
conj, transpose, adjoint, vec,
20+
conj, transpose, adjoint, permutedims, vec,
2121
exp, log, sqrt, cos, sin, tan, csc, sec, cot,
2222
cosh, sinh, tanh, csch, sech, coth,
2323
acos, asin, atan, acsc, asec, acot,
@@ -36,7 +36,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3636
combine_eltypes, DefaultArrayStyle, instantiate, materialize,
3737
materialize!, eltypes
3838

39-
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot
39+
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot, factorize, qr, lu, cholesky
4040

4141
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4242

src/lazyapplying.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ end
264264
@inline lazy_getindex(A, I...) = sub_materialize(view(A, I...))
265265

266266

267-
@inline getindex(A::ApplyMatrix, kr::Colon, jr::Colon) = lazy_getindex(A, kr, jr)
268-
@inline getindex(A::ApplyMatrix, kr::Colon, jr::AbstractUnitRange) = lazy_getindex(A, kr, jr)
269-
@inline getindex(A::ApplyMatrix, kr::AbstractUnitRange, jr::Colon) = lazy_getindex(A, kr, jr)
270-
@inline getindex(A::ApplyMatrix, kr::AbstractUnitRange, jr::AbstractUnitRange) = lazy_getindex(A, kr, jr)
267+
@inline getindex(A::LazyMatrix, kr::Colon, jr::Colon) = lazy_getindex(A, kr, jr)
268+
@inline getindex(A::LazyMatrix, kr::Colon, jr::AbstractUnitRange) = lazy_getindex(A, kr, jr)
269+
@inline getindex(A::LazyMatrix, kr::AbstractUnitRange, jr::Colon) = lazy_getindex(A, kr, jr)
270+
@inline getindex(A::LazyMatrix, kr::AbstractUnitRange, jr::AbstractUnitRange) = lazy_getindex(A, kr, jr)
271271

272272

273273
diagonallayout(::LazyLayout) = DiagonalLayout{LazyLayout}()

src/lazybroadcasting.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,10 @@ _broadcast_rowsupport(sz, A::AbstractMatrix, j) = size(A,2) == 1 ? OneTo(sz[2])
164164

165165
colsupport(::BroadcastLayout{typeof(*)}, A, j) = intersect(_broadcast_colsupport.(Ref(size(A)), A.args, j)...)
166166
rowsupport(::BroadcastLayout{typeof(*)}, A, j) = intersect(_broadcast_rowsupport.(Ref(size(A)), A.args, j)...)
167+
168+
for op in (:+, :-)
169+
@eval begin
170+
colsupport(::BroadcastLayout{typeof($op)}, A, j) = convexunion(_broadcast_colsupport.(Ref(size(A)), A.args, j)...)
171+
rowsupport(::BroadcastLayout{typeof($op)}, A, j) = convexunion(_broadcast_rowsupport.(Ref(size(A)), A.args, j)...)
172+
end
173+
end

src/lazyconcat.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,27 @@ end
118118
# based on Base/array.jl, Base/abstractarray.jl
119119

120120
function copyto!(dest::AbstractMatrix, V::Vcat{<:Any,2})
121+
arrays = V.args
122+
nargs = length(arrays)
123+
nrows = size(dest,1)
124+
nrows == sum(a->size(a, 1), arrays) || throw(DimensionMismatch("sum of rows each matrix must equal $nrows"))
125+
ncols = size(dest, 2)
126+
for a in arrays
127+
if size(a, 2) != ncols
128+
throw(DimensionMismatch("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
129+
end
130+
end
131+
pos = 1
132+
for a in arrays
133+
p1 = pos+size(a,1)-1
134+
dest[pos:p1, :] .= a
135+
pos = p1+1
136+
end
137+
return dest
138+
end
139+
140+
# this is repeated to avoid allocation in .=
141+
function copyto!(dest::AbstractMatrix, V::Vcat{<:Any,2,<:Tuple{Vararg{<:AbstractMatrix}}})
121142
arrays = V.args
122143
nargs = length(arrays)
123144
nrows = size(dest,1)
@@ -259,6 +280,12 @@ _vec(a::AbstractArray) = vec(a)
259280
_vec(a::Adjoint{<:Number,<:AbstractVector}) = _vec(parent(a))
260281
vec(A::Hcat) = Vcat(_vec.(A.args)...)
261282

283+
_permutedims(a) = a
284+
_permutedims(a::AbstractArray) = permutedims(a)
285+
286+
permutedims(A::Hcat{T}) where T = Vcat{T}(map(_permutedims,A.args)...)
287+
permutedims(A::Vcat{T}) where T = Hcat{T}(map(_permutedims,A.args)...)
288+
262289

263290
#####
264291
# broadcasting

src/linalg/factorizations.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,17 @@ function copyto!(dest::AbstractArray, M::Ldiv{QLayout})
3333
materialize!(Ldiv(A,dest))
3434
end
3535

36-
materialize!(M::Ldiv{QLayout}) = materialize!(Lmul(M.A',M.B))
36+
materialize!(M::Ldiv{QLayout}) = materialize!(Lmul(M.A',M.B))
37+
38+
factorizestyle(_) = DefaultArrayApplyStyle()
39+
40+
for op in (:factorize, :qr, :lu, :cholesky)
41+
@eval begin
42+
$op(B::LazyMatrix) = apply($op, B)
43+
ApplyStyle(::typeof($op), B::Type{<:AbstractMatrix}) = factorizestyle(MemoryLayout(B))
44+
materialize(A::Applied{DefaultArrayApplyStyle,typeof($op),<:Tuple{<:AbstractMatrix{T}}}) where T =
45+
Base.invoke($op, Tuple{AbstractMatrix{T}}, A.args...)
46+
47+
eltype(A::Applied{<:Any,typeof($op)}) = float(eltype(first(A.args)))
48+
end
49+
end

test/concattests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, PaddedLayout, materialize!, M
1919
@test copy(A) !== A
2020
@test vec(A) === A
2121
@test A' == transpose(A) == Vector(A)'
22+
@test permutedims(A) == permutedims(Vector(A))
2223

2324
A = @inferred(Vcat(1:10, 1:20))
2425
@test @inferred(length(A)) == 30
@@ -35,6 +36,7 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, PaddedLayout, materialize!, M
3536
@test A' == transpose(A) == Vector(A)'
3637
@test A' === Hcat((1:10)', (1:20)')
3738
@test transpose(A) === Hcat(transpose(1:10), transpose(1:20))
39+
@test permutedims(A) == permutedims(Vector(A))
3840

3941
A = Vcat(randn(2,10), randn(4,10))
4042
@test @inferred(length(A)) == 60
@@ -51,6 +53,7 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, PaddedLayout, materialize!, M
5153
@test copy(A) !== A
5254
@test vec(A) == vec(Matrix(A))
5355
@test A' == transpose(A) == Matrix(A)'
56+
@test permutedims(A) == permutedims(Matrix(A))
5457

5558
A = Vcat(randn(2,10).+im.*randn(2,10), randn(4,10).+im.*randn(4,10))
5659
@test eltype(A) == ComplexF64
@@ -69,6 +72,7 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, PaddedLayout, materialize!, M
6972
@test vec(A) == vec(Matrix(A))
7073
@test A' == Matrix(A)'
7174
@test transpose(A) == transpose(Matrix(A))
75+
@test permutedims(A) == permutedims(Matrix(A))
7276

7377
@test Vcat() isa Vcat{Any,1,Tuple{}}
7478

@@ -77,6 +81,7 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, PaddedLayout, materialize!, M
7781
@test A[1,1] == 1.0
7882
@test A[2,1] == 0.0
7983
@test axes(A) == (Base.OneTo(4),Base.OneTo(1))
84+
@test permutedims(A) == permutedims(Matrix(A))
8085
end
8186
@testset "Hcat" begin
8287
A = @inferred(Hcat(1:10, 2:11))

0 commit comments

Comments
 (0)