Skip to content

Commit bc694f4

Browse files
authored
allow dispatching on axes for dot (#78)
* allow dispatching on axes for dot * Update Project.toml * add adtrans mul tests
1 parent c9248e0 commit bc694f4

File tree

5 files changed

+39
-8
lines changed

5 files changed

+39
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1515
ArrayLayouts = "0.7"
1616
DomainSets = "0.5"
1717
FillArrays = "0.11, 0.12"
18-
LazyArrays = "0.21.5"
18+
LazyArrays = "0.21.5, 0.22"
1919
StaticArrays = "1"
2020
julia = "1.6"
2121

src/QuasiArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import Base.Broadcast: materialize, materialize!, BroadcastStyle, AbstractArrayS
3838

3939
import LinearAlgebra: transpose, adjoint, checkeltype_adjoint, checkeltype_transpose, Diagonal,
4040
AbstractTriangular, pinv, inv, promote_leaf_eltypes, power_by_squaring,
41-
integerpow, schurpow, tr, factorize, copy_oftype, rank
41+
integerpow, schurpow, tr, factorize, copy_oftype, rank, dot
4242

4343
import ArrayLayouts: indextype, concretize, fillzeros, OnesLayout, AbstractFillLayout, FillLayout, ZerosLayout, diagonallayout, diagonaldata, diagonal
4444
import LazyArrays: MemoryLayout, UnknownLayout, Mul, ApplyLayout, BroadcastLayout,

src/quasiadjtrans.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,21 @@ map(f, tvs::QuasiTransposeAbsVec...) = transpose(map((xs...) -> transpose(f(tran
167167
## multiplication *
168168

169169
# QuasiAdjoint/QuasiTranspose-vector * vector
170+
171+
# allow re-expansion using Legendre
172+
_dot(ax, a, b) = Base.invoke(dot, NTuple{2,Any}, a, b)
173+
174+
function dot(a::AbstractQuasiArray, b::AbstractQuasiArray)
175+
ax = axes(a,1)
176+
ax == axes(b,1) || throw(DimensionMismatch())
177+
_dot(ax, a, b)
178+
end
179+
170180
*(u::QuasiAdjointAbsVec, v::AbstractQuasiVector) = dot(u.parent, v)
171181
*(u::QuasiTransposeAbsVec{T}, v::AbstractQuasiVector{T}) where {T<:Real} = dot(u.parent, v)
172182
function *(u::QuasiTransposeAbsVec, v::AbstractQuasiVector)
173-
@assert !has_offset_axes(u, v)
174-
@boundscheck length(u) == length(v) || throw(DimensionMismatch())
175-
return sum(@inbounds(u[k]*v[k]) for k in 1:length(u))
183+
@boundscheck axes(u,2) == axes(v,1) || throw(DimensionMismatch())
184+
return sum(@inbounds(u.parent[k]*v[k]) for k in axes(v,1))
176185
end
177186
# vector * QuasiAdjoint/QuasiTranspose-vector
178187
*(u::AbstractQuasiVector, v::AdjOrTransAbsVec) = broadcast(*, u, v)
@@ -229,4 +238,18 @@ arguments(LAY::ApplyLayout{typeof(*)}, V::QuasiTranspose) = reverse(transpose.(a
229238

230239
# This is used in ContinuumArrays.jl to ensure x' is lazy
231240
BroadcastStyle(::Type{<:QuasiAdjoint{<:Any,<:Inclusion}}) = LazyQuasiArrayStyle{2}()
232-
BroadcastStyle(::Type{<:QuasiTranspose{<:Any,<:Inclusion}}) = LazyQuasiArrayStyle{2}()
241+
BroadcastStyle(::Type{<:QuasiTranspose{<:Any,<:Inclusion}}) = LazyQuasiArrayStyle{2}()
242+
243+
244+
245+
###
246+
# adjoint concat support
247+
###
248+
249+
arguments(::ApplyLayout{typeof(vcat)}, A::QuasiAdjoint) = map(adjoint, arguments(ApplyLayout{typeof(hcat)}(), parent(A)))
250+
arguments(::ApplyLayout{typeof(hcat)}, A::QuasiAdjoint) = map(adjoint, arguments(ApplyLayout{typeof(vcat)}(), parent(A)))
251+
arguments(::ApplyLayout{typeof(vcat)}, A::QuasiTranspose) = map(transpose, arguments(ApplyLayout{typeof(hcat)}(), parent(A)))
252+
arguments(::ApplyLayout{typeof(hcat)}, A::QuasiTranspose) = map(transpose, arguments(ApplyLayout{typeof(vcat)}(), parent(A)))
253+
254+
255+
copy(M::Mul{ApplyLayout{typeof(vcat)},QuasiArrayLayout}) = vcat((arguments(vcat, M.A) .* Ref(M.B))...)

src/quasiconcat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ function _getindex(::Type{IND}, A::UnionVcat{T,2}, (x,j)::IND) where {IND,T}
7171
x in axes(a,1) && return convert(T,a[x,j])::T
7272
end
7373
throw(BoundsError(A, I))
74-
end
74+
end

test/test_matmul.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using QuasiArrays, ArrayLayouts, Test
1+
using QuasiArrays, ArrayLayouts, LinearAlgebra, Test
22
import QuasiArrays: apply
33

44
@testset "Multiplication" begin
@@ -35,4 +35,12 @@ import QuasiArrays: apply
3535
@test Array(A*A) Array(A)^2
3636
@test Array(A*A*A) Array(A)^3
3737
end
38+
39+
@testset "absvec" begin
40+
a = QuasiArray(rand(3),(0:0.5:1,))
41+
c = QuasiArray(rand(3) .+ im .* randn(3),(0:0.5:1,))
42+
@test a'a transpose(a)a dot(a,a)
43+
@test c'c dot(c,c)
44+
@test transpose(c)c transpose(c.parent)*c.parent
45+
end
3846
end

0 commit comments

Comments
 (0)