Skip to content

Commit c30b8cd

Browse files
authored
Add Trues type and methods (#110)
* Add Trues * Export * Test Trues * Add comments * Bump release * Check axes * More strings in DimensionMismatch * Use BoundsError * Refine errors * Use DimensionMismatch * Test throws * Restrict getindex to OneTo axes * Revert strings in throw * Restrict setindex! to OneTo
1 parent 6a1b630 commit c30b8cd

File tree

4 files changed

+70
-4
lines changed

4 files changed

+70
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.8.14"
3+
version = "0.9"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1515

1616

1717

18-
export Zeros, Ones, Fill, Eye
18+
export Zeros, Ones, Fill, Eye, Trues, Falses
1919

2020
"""
2121
AbstractFill{T, N, Axes} <: AbstractArray{T, N}
@@ -269,11 +269,11 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
269269
$Typ{T}(length.(kj))
270270
end
271271
function getindex(A::$Typ{T}, kr::AbstractVector{Bool}) where T
272-
length(A) == length(kr) || throw(DimensionMismatch())
272+
length(A) == length(kr) || throw(DimensionMismatch("lengths must match"))
273273
$Typ{T}(count(kr))
274274
end
275275
function getindex(A::$Typ{T}, kr::AbstractArray{Bool}) where T
276-
size(A) == size(kr) || throw(DimensionMismatch())
276+
size(A) == size(kr) || throw(DimensionMismatch("sizes must match"))
277277
$Typ{T}(count(kr))
278278
end
279279

@@ -563,6 +563,7 @@ count(f, x::AbstractFill) = f(getindex_value(x)) ? length(x) : 0
563563

564564
include("fillalgebra.jl")
565565
include("fillbroadcast.jl")
566+
include("trues.jl")
566567

567568
##
568569
# print

src/trues.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
"""
3+
Trues = Ones{Bool, N, Axes} where {N, Axes}
4+
5+
Lazy version of `trues` with axes.
6+
Typically created using `Trues(dims)` or `Trues(dims...)`
7+
8+
# Example
9+
```jldoctest
10+
julia> Trues(1,3)
11+
1×3 Ones{Bool,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}} = true
12+
13+
julia> Trues((2,3))
14+
2×3 Ones{Bool,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}} = true
15+
```
16+
"""
17+
const Trues = Ones{Bool, N, Axes} where {N, Axes}
18+
19+
20+
""" `Falses = Zeros{Bool, N, Axes}` (see `Trues`) """
21+
const Falses = Zeros{Bool, N, Axes} where {N, Axes}
22+
23+
24+
# y[mask] = x when mask isa Trues (cf y[:] = x)
25+
# Supported here only for arrays with standard OneTo axes.
26+
function Base.setindex!(y::AbstractArray{T,N}, x,
27+
mask::Trues{N, NTuple{N,Base.OneTo{Int}}},
28+
) where {T,N}
29+
if axes(x) isa NTuple{N,Base.OneTo{Int}} &&
30+
axes(y) isa NTuple{N,Base.OneTo{Int}}
31+
@boundscheck size(y) == size(mask) || throw(BoundsError(y, mask))
32+
@boundscheck size(x) == size(mask) || throw(DimensionMismatch(
33+
"tried to assign $(length(x)) elements to $(length(y)) destinations"))
34+
@boundscheck checkbounds(y, mask)
35+
return copyto!(y, x)
36+
end
37+
return setindex!(y, x, trues(size(mask))) # fall back on usual setindex!
38+
end
39+
40+
# x[mask] when mask isa Trues (cf x[trues(size(x))] or x[:])
41+
# Supported here only for arrays with standard OneTo axes.
42+
function Base.getindex(x::AbstractArray{T,N},
43+
mask::Trues{N, NTuple{N,Base.OneTo{Int}}},
44+
) where {T,N}
45+
if axes(x) isa NTuple{N,Base.OneTo{Int}} where N
46+
@boundscheck size(x) == size(mask) || throw(BoundsError(x, mask))
47+
return vec(x)
48+
end
49+
return x[trues(size(x))] # else revert to usual getindex method
50+
end

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,18 @@ end
10151015
@test axes(U) == (Base.OneTo(3),Base.OneTo(3))
10161016
end
10171017
end
1018+
1019+
@testset "Trues" begin
1020+
@test Trues(2,3) == Trues((2,3)) == trues(2,3)
1021+
@test Falses(2,3) == Falses((2,3)) == falses(2,3)
1022+
dim = (4,5)
1023+
mask = Trues(dim)
1024+
x = randn(dim)
1025+
@test x[mask] == vec(x) # getindex
1026+
y = similar(x)
1027+
y[mask] = x # setindex!
1028+
@test y == x
1029+
@test_throws BoundsError ones(3)[Trues(2)]
1030+
@test_throws BoundsError setindex!(ones(3), zeros(3), Trues(2))
1031+
@test_throws DimensionMismatch setindex!(ones(2), zeros(3), Trues(2))
1032+
end

0 commit comments

Comments
 (0)