Skip to content

Commit 936a633

Browse files
authored
fix rowvector methods (#193)
1 parent 835a716 commit 936a633

File tree

2 files changed

+66
-44
lines changed

2 files changed

+66
-44
lines changed

src/LinearAlgebra/rowvector.jl

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ import Base: convert, similar, length, size, axes, IndexStyle,
99
RowVector(vector)
1010
1111
A lazy-view wrapper of an `AbstractVector`, which turns a length-`n` vector into a `1×n`
12-
shaped row vector and represents the transpose of a vector (the elements are also transposed
13-
recursively).
12+
shaped row vector and represents the transpose of a vector (although unlike `transpose`,
13+
the elements are *not* transposed recursively).
1414
1515
By convention, a vector can be multiplied by a matrix on its left (`A * v`) whereas a row
16-
vector can be multiplied by a matrix on its right (such that `transpose(v) * A = transpose(transpose(A) * v)`). It
16+
vector can be multiplied by a matrix on its right (such that `RowVector(v) * A = RowVector(transpose(A) * v)`). It
1717
differs from a `1×n`-sized matrix by the facts that its transpose returns a vector and the
18-
inner product `transpose(v1) * v2` returns a scalar, but will otherwise behave similarly.
18+
inner product `RowVector(v1) * v2` returns a scalar, but will otherwise behave similarly.
1919
"""
2020
struct RowVector{T,V<:AbstractVector} <: AbstractMatrix{T}
2121
vec::V
@@ -29,13 +29,13 @@ end
2929
@inline RowVector{T}(vec::AbstractVector{T}) where {T} = RowVector{T,typeof(vec)}(vec)
3030

3131
# Constructors that take a size and default to Array
32-
@inline RowVector{T}(n::Int) where {T} = RowVector{T}(Vector{T}(n))
32+
@inline RowVector{T}(n::Int) where {T} = RowVector{T}(Vector{T}(undef,n))
3333
@inline RowVector{T}(n1::Int, n2::Int) where {T} = n1 == 1 ?
34-
RowVector{T}(Vector{T}(n2)) :
34+
RowVector{T}(n2) :
3535
error("RowVector expects 1×N size, got ($n1,$n2)")
36-
@inline RowVector{T}(n::Tuple{Int}) where {T} = RowVector{T}(Vector{T}(n[1]))
36+
@inline RowVector{T}(n::Tuple{Int}) where {T} = RowVector{T}(n[1])
3737
@inline RowVector{T}(n::Tuple{Int,Int}) where {T} = n[1] == 1 ?
38-
RowVector{T}(Vector{T}(n[2])) :
38+
RowVector{T}(n[2]) :
3939
error("RowVector expects 1×N size, got $n")
4040

4141
# Conversion of underlying storage
@@ -44,7 +44,7 @@ convert(::Type{RowVector{T,V}}, rowvec::RowVector) where {T,V<:AbstractVector} =
4444

4545
# similar tries to maintain the RowVector wrapper and the parent type
4646
@inline similar(rowvec::RowVector) = RowVector(similar(parent(rowvec)))
47-
@inline similar(rowvec::RowVector, ::Type{T}) where {T} = RowVector(similar(parent(rowvec), transpose_type(T)))
47+
@inline similar(rowvec::RowVector, ::Type{T}) where {T} = RowVector(similar(parent(rowvec), T))
4848

4949
# Resizing similar currently loses its RowVector property.
5050
@inline similar(rowvec::RowVector, ::Type{T}, dims::Dims{N}) where {T,N} = similar(parent(rowvec), T, dims)
@@ -54,40 +54,17 @@ parent(rowvec::RowVector) = rowvec.vec
5454
# AbstractArray interface
5555
@inline length(rowvec::RowVector) = length(rowvec.vec)
5656
@inline size(rowvec::RowVector) = (1, length(rowvec.vec))
57-
@inline size(rowvec::RowVector, d) = ifelse(d==2, length(rowvec.vec), 1)
58-
@inline axes(rowvec::RowVector) = (Base.OneTo(1), axes(rowvec.vec)[1])
59-
@inline axes(rowvec::RowVector, d) = ifelse(d == 2, axes(rowvec.vec)[1], Base.OneTo(1))
57+
@inline axes(rowvec::RowVector) = (Base.OneTo(1), axes(rowvec.vec, 1))
6058
IndexStyle(::RowVector) = IndexLinear()
6159
IndexStyle(::Type{<:RowVector}) = IndexLinear()
6260

63-
64-
@propagate_inbounds getindex(rowvec::RowVector, i) = rowvec.vec[i]
65-
@propagate_inbounds setindex!(rowvec::RowVector, v, i) = setindex!(rowvec.vec, v, i)
66-
67-
# Cartesian indexing is distorted by getindex
68-
# Furthermore, Cartesian indexes don't have to match shape, apparently!
69-
@inline function getindex(rowvec::RowVector, i::CartesianIndex)
70-
@boundscheck if !(i.I[1] == 1 && i.I[2] axes(rowvec.vec)[1] && check_tail_indices(i.I...))
71-
throw(BoundsError(rowvec, i.I))
72-
end
73-
@inbounds return rowvec.vec[i.I[2]]
74-
end
75-
@inline function setindex!(rowvec::RowVector, v, i::CartesianIndex)
76-
@boundscheck if !(i.I[1] == 1 && i.I[2] axes(rowvec.vec)[1] && check_tail_indices(i.I...))
77-
throw(BoundsError(rowvec, i.I))
78-
end
79-
@inbounds rowvec.vec[i.I[2]] = v
80-
end
81-
82-
@propagate_inbounds getindex(rowvec::RowVector, ::CartesianIndex{0}) = getindex(rowvec)
83-
@propagate_inbounds getindex(rowvec::RowVector, i::CartesianIndex{1}) = getindex(rowvec, i.I[1])
84-
85-
@propagate_inbounds setindex!(rowvec::RowVector, v, ::CartesianIndex{0}) = setindex!(rowvec, v)
86-
@propagate_inbounds setindex!(rowvec::RowVector, v, i::CartesianIndex{1}) = setindex!(rowvec, v, i.I[1])
61+
@propagate_inbounds getindex(rowvec::RowVector, i::Int) = rowvec.vec[i]
62+
@propagate_inbounds setindex!(rowvec::RowVector, v, i::Int) = setindex!(rowvec.vec, v, i)
8763

8864
# helper function for below
89-
@inline to_vec(rowvec::RowVector) = rowvec.vec
90-
65+
to_vec(r::RowVector) = parent(r)
66+
to_vec(x) = x
67+
@inline to_vecs(rowvecs...) = map(to_vec, rowvecs)
9168
# map: Preserve the RowVector by un-wrapping and re-wrapping, but note that `f`
9269
# expects to operate within the transposed domain, so to_vec transposes the elements
9370
@inline map(f, rowvecs::RowVector...) = RowVector(map(f, to_vecs(rowvecs...)...))
@@ -98,13 +75,13 @@ end
9875

9976
# Horizontal concatenation #
10077

101-
@inline hcat(X::RowVector...) = RowVector(vcat(X...))
102-
@inline hcat(X::Union{RowVector,Number}...) = RowVector(vcat(X...))
78+
@inline hcat(X::RowVector...) = RowVector(mapreduce(parent, vcat, X))
79+
@inline hcat(X::Union{RowVector,Number}...) = RowVector(mapreduce(to_vec, vcat, X))
10380

10481
@inline typed_hcat(::Type{T}, X::RowVector...) where {T} =
105-
RowVector(typed_vcat(T, X...))
82+
RowVector(Base.typed_vcat(T, to_vecs(X...)...))
10683
@inline typed_hcat(::Type{T}, X::Union{RowVector,Number}...) where {T} =
107-
RowVector(typed_vcat(T, X...))
84+
RowVector(Base.typed_vcat(T, to_vecs(X...)...))
10885

10986
# Multiplication #
11087

@@ -116,9 +93,9 @@ end
11693
if length(rowvec) != length(vec)
11794
throw(DimensionMismatch("A has dimensions $(size(rowvec)) but B has dimensions $(size(vec))"))
11895
end
119-
sum(@inbounds(rowvec[i]*vec[i]) for i = 1:length(vec))
96+
sum(@inbounds(rowvec[i]*vec[i]) for i in eachindex(rowvec, vec))
12097
end
121-
@inline *(rowvec::RowVector, mat::AbstractMatrix) = RowVector(transpose(transpose()mat) * rowvec.vec)
98+
@inline *(rowvec::RowVector, mat::AbstractMatrix) = RowVector(transpose(mat) * parent(rowvec))
12299
*(::RowVector, ::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
123100
@inline *(vec::AbstractVector, rowvec::RowVector) = vec .* rowvec
124101

test/runtests.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,51 @@ end
270270
end
271271
end
272272

273+
@testset "RowVector" begin
274+
@testset "constructors" begin
275+
for s in [(2,), (1,2)], sz in [s, (s,)]
276+
b = fill!(ApproxFunBase.RowVector{Int}(sz...), 2)
277+
@test size(b) == (1,2)
278+
@test all(==(2), b)
279+
end
280+
end
281+
# for a vector of numbers, RowVector should be identical to transpose
282+
a = Float64.(1:4)
283+
at = transpose(a)
284+
b = ApproxFunBase.RowVector(a)
285+
@test b == at
286+
for inds in [eachindex(b), CartesianIndices(b)]
287+
for i in inds
288+
@test b[i] == at[i]
289+
end
290+
end
291+
M = Float64.(reshape(1:16, 4, 4))
292+
@test b * M == at * M
293+
@test b * a == at * a
294+
@test b * Float32.(a) == at * Float32.(a)
295+
@test a * b == a * at
296+
@test map(x->x^2, b) == map(x->x^2, at)
297+
@test b.^2 == at.^2
298+
@test hcat(b, b) == hcat(at, at)
299+
@test vcat(b, b) == vcat(at, at)
300+
@test hcat(b, 1) == hcat(at, 1)
301+
c = Float32[b 1]
302+
@test eltype(c) == Float32
303+
@test c == [b 1]
304+
c = Float32[b b]
305+
@test eltype(c) == Float32
306+
@test c == Float32[at at]
307+
308+
# setindex
309+
b[2] = 30
310+
@test b[2] == b[1,2] == 30
311+
@test a[2] == 30
312+
313+
a = rand(1)
314+
b = ApproxFunBase.RowVector(a)
315+
@test b[] == b[CartesianIndex()] == b[CartesianIndex(1)] == a[]
316+
end
317+
273318
@testset "misc" begin
274319
a = @inferred ApproxFunBase.specialfunctionnormalizationpoint(exp,real,Fun())
275320
@test a[1] == 1

0 commit comments

Comments
 (0)