Skip to content

Commit 81f8ff8

Browse files
authored
Support adjoint strides (#13)
* Support adjoint strides * v0.2.3 * Remove MemoryLayout(typeof(...)) usage * Update test_muladd.jl * Random in tests
1 parent 68098b6 commit 81f8ff8

File tree

8 files changed

+131
-107
lines changed

8 files changed

+131
-107
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -13,6 +13,7 @@ julia = "1"
1313

1414
[extras]
1515
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617

1718
[targets]
18-
test = ["Test"]
19+
test = ["Test", "Random"]

src/ArrayLayouts.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ import Base: AbstractArray, AbstractMatrix, AbstractVector,
2626
AbstractArray, AbstractVector, axes, (:), _sub2ind_recurse, broadcast, promote_eltypeof,
2727
similar, @_gc_preserve_end, @_gc_preserve_begin,
2828
@nexprs, @ncall, @ntuple, tuple_type_tail,
29-
all, any, isbitsunion, issubset, replace_in_print_matrix, replace_with_centered_mark
29+
all, any, isbitsunion, issubset, replace_in_print_matrix, replace_with_centered_mark,
30+
strides, unsafe_convert
3031

3132
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcasted,
3233
combine_eltypes, DefaultArrayStyle, instantiate, materialize,
@@ -64,6 +65,15 @@ abstract type LayoutArray{T,N} <: AbstractArray{T,N} end
6465
const LayoutMatrix{T} = LayoutArray{T,2}
6566
const LayoutVector{T} = LayoutArray{T,1}
6667

68+
## TODO: Following are type piracy whch may be removed in Julia v1.5
69+
_transpose_strides(a) = (a,1)
70+
_transpose_strides(a,b) = (b,a)
71+
strides(A::Adjoint) = _transpose_strides(strides(parent(A))...)
72+
strides(A::Transpose) = _transpose_strides(strides(parent(A))...)
73+
74+
unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real}) where T<:Real = unsafe_convert(Ptr{T}, parent(A))
75+
unsafe_convert(::Type{Ptr{T}}, A::Transpose) where T = unsafe_convert(Ptr{T}, parent(A))
76+
6777
include("memorylayout.jl")
6878
include("muladd.jl")
6979
include("lmul.jl")
@@ -74,7 +84,7 @@ include("factorizations.jl")
7484

7585
@inline sub_materialize(_, V, _) = Array(V)
7686
@inline sub_materialize(L, V) = sub_materialize(L, V, axes(V))
77-
@inline sub_materialize(V::SubArray) = sub_materialize(MemoryLayout(typeof(V)), V)
87+
@inline sub_materialize(V::SubArray) = sub_materialize(MemoryLayout(V), V)
7888

7989
@inline layout_getindex(A, I...) = sub_materialize(view(A, I...))
8090

@@ -104,22 +114,22 @@ _copyto!(_, _, dest::AbstractArray{T,N}, src::AbstractArray{V,N}) where {T,V,N}
104114

105115

106116
copyto!(dest::LayoutArray{<:Any,N}, src::LayoutArray{<:Any,N}) where N =
107-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
117+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
108118
copyto!(dest::AbstractArray{<:Any,N}, src::LayoutArray{<:Any,N}) where N =
109-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
119+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
110120
copyto!(dest::LayoutArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N =
111-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
121+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
112122

113123
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
114-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
124+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
115125
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::LayoutArray{<:Any,N}) where N =
116-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
126+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
117127
copyto!(dest::LayoutArray{<:Any,N}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
118-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
128+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
119129
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::AbstractArray{<:Any,N}) where N =
120-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
130+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
121131
copyto!(dest::AbstractArray{<:Any,N}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
122-
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
132+
_copyto!(MemoryLayout(dest), MemoryLayout(src), dest, src)
123133

124134

125135

src/factorizations.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ _factorize(layout, axes, A) = Base.invoke(factorize, Tuple{AbstractMatrix{eltype
3636

3737
macro _layoutfactorizations(Typ)
3838
esc(quote
39-
LinearAlgebra.qr(A::$Typ, args...; kwds...) = ArrayLayouts._qr(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
40-
LinearAlgebra.qr!(A::$Typ, args...; kwds...) = ArrayLayouts._qr!(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
41-
LinearAlgebra.lu(A::$Typ, pivot::Union{Val{false}, Val{true}}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, pivot; kwds...)
42-
LinearAlgebra.lu(A::$Typ{T}; kwds...) where T = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A; kwds...)
43-
LinearAlgebra.lu!(A::$Typ, args...; kwds...) = ArrayLayouts._lu!(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
44-
LinearAlgebra.factorize(A::$Typ) = ArrayLayouts._factorize(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A)
39+
LinearAlgebra.qr(A::$Typ, args...; kwds...) = ArrayLayouts._qr(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
40+
LinearAlgebra.qr!(A::$Typ, args...; kwds...) = ArrayLayouts._qr!(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
41+
LinearAlgebra.lu(A::$Typ, pivot::Union{Val{false}, Val{true}}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A, pivot; kwds...)
42+
LinearAlgebra.lu(A::$Typ{T}; kwds...) where T = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A; kwds...)
43+
LinearAlgebra.lu!(A::$Typ, args...; kwds...) = ArrayLayouts._lu!(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
44+
LinearAlgebra.factorize(A::$Typ) = ArrayLayouts._factorize(ArrayLayouts.MemoryLayout(A), axes(A), A)
4545
end)
4646
end
4747

src/ldiv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
Rdiv(instantiate(L.A), instantiate(L.B))
6464
end
6565

66-
__ldiv!(::Mat, ::Mat, B) where Mat = error("Overload materialize!(::Ldiv{$(typeof(MemoryLayout(Mat))),$(typeof(MemoryLayout(typeof(B))))})")
66+
__ldiv!(::Mat, ::Mat, B) where Mat = error("Overload materialize!(::Ldiv{$(typeof(MemoryLayout(Mat))),$(typeof(MemoryLayout(B)))})")
6767
__ldiv!(_, F, B) = ldiv!(F, B)
6868
@inline _ldiv!(A, B) = __ldiv!(A, factorize(A), B)
6969
@inline _ldiv!(A::Factorization, B) = ldiv!(A, B)

src/memorylayout.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ rowsupport(_, A, k) = axes(A,2)
462462
463463
gives an iterator containing the possible non-zero entries in the k-th row of A.
464464
"""
465-
rowsupport(A, k) = rowsupport(MemoryLayout(typeof(A)), A, k)
465+
rowsupport(A, k) = rowsupport(MemoryLayout(A), A, k)
466466
rowsupport(A) = rowsupport(A, axes(A,1))
467467

468468
colsupport(_, A, j) = axes(A,1)
@@ -472,7 +472,7 @@ colsupport(_, A, j) = axes(A,1)
472472
473473
gives an iterator containing the possible non-zero entries in the j-th column of A.
474474
"""
475-
colsupport(A, j) = colsupport(MemoryLayout(typeof(A)), A, j)
475+
colsupport(A, j) = colsupport(MemoryLayout(A), A, j)
476476
colsupport(A) = colsupport(A, axes(A,2))
477477

478478
rowsupport(::ZerosLayout, A, _) = 1:0

src/muladd.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,6 @@ end
257257
end
258258
end
259259

260-
261260
@inline materialize!(M::BlasMatMulVecAdd{<:AbstractColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout}) =
262261
_gemv!('N', M.α, M.A, M.B, M.β, M.C)
263262
@inline materialize!(M::BlasMatMulVecAdd{<:AbstractRowMajor,<:AbstractStridedLayout,<:AbstractStridedLayout}) =

0 commit comments

Comments
 (0)