Skip to content

Commit 3368d85

Browse files
authored
Allow to pass multiple predicates in Cols and mix them with other selectors (#3279)
1 parent cf893d2 commit 3368d85

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
* Add support for `operator` keyword argument in `Cols`
1313
to take a set operation to apply to passed selectors (`union` by default)
1414
([3224](https://github.com/JuliaData/DataFrames.jl/pull/3224))
15+
* Allow to pass multiple predicates in `Cols` and mix them with
16+
other selectors
17+
([3279](https://github.com/JuliaData/DataFrames.jl/pull/3279))
1518
* Improve support for setting group order in `groupby`
1619
([3253](https://github.com/JuliaData/DataFrames.jl/pull/3253))
1720
* Joining functions now support `order` keyword argument allowing the user

src/other/index.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,16 @@ end
230230
@inline Base.getindex(x::AbstractIndex, idx::All) =
231231
isempty(idx.cols) ? (1:length(x)) : throw(ArgumentError("All(args...) is not supported: use Cols(args...) instead"))
232232

233+
@inline _getindex_cols(x::AbstractIndex, idx::Any) = x[idx]
234+
@inline _getindex_cols(x::AbstractIndex, idx::Function) = findall(idx, names(x))
235+
# the definition below is needed because `:` is a Function
236+
@inline _getindex_cols(x::AbstractIndex, idx::Colon) = x[idx]
237+
233238
@inline function Base.getindex(x::AbstractIndex, idx::Cols)
234239
isempty(idx.cols) && return Int[]
235-
return idx.operator(getindex.(Ref(x), idx.cols)...)
240+
return idx.operator(_getindex_cols.(Ref(x), idx.cols)...)
236241
end
237242

238-
# the definition below is needed because `:` is a Function
239-
@inline Base.getindex(x::AbstractIndex, idx::Cols{Tuple{typeof(:)}}) = x[:]
240-
@inline Base.getindex(x::AbstractIndex, idx::Cols{<:Tuple{Function}}) =
241-
findall(idx.cols[1], names(x))
242-
243243
@inline function Base.getindex(x::AbstractIndex, idx::AbstractVector{<:Integer})
244244
if any(v -> v isa Bool, idx)
245245
throw(ArgumentError("Bool values except for AbstractVector{Bool} are not " *

test/index.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,10 @@ end
477477
@test df[:, Cols(x -> x[1] == 'a')] == df[:, [1, 2]]
478478
@test df[:, Cols(x -> x[end] == '1')] == df[:, [1, 3]]
479479
@test df[:, Cols(x -> x[end] == '3')] == DataFrame()
480-
@test_throws ArgumentError df[:, Cols(x -> true, 1)]
481-
@test_throws ArgumentError df[:, Cols(1, x -> true)]
480+
@test df[:, Cols(x -> true, 1)] == df
481+
@test df[:, Cols(1, x -> true)] == df
482+
@test df[:, Cols(x -> true, 1, operator=intersect)] == DataFrame(a1=1)
483+
@test df[:, Cols(1, x -> true, operator=intersect)] == DataFrame(a1=1)
482484

483485
@test ncol(select(df, Cols(operator=intersect))) == 0
484486
@test ncol(df[:, Cols(operator=intersect)]) == 0
@@ -539,8 +541,20 @@ end
539541
@test df[:, Cols(x -> x[1] == 'a', operator=intersect)] == df[:, [1, 2]]
540542
@test df[:, Cols(x -> x[end] == '1', operator=intersect)] == df[:, [1, 3]]
541543
@test df[:, Cols(x -> x[end] == '3', operator=intersect)] == DataFrame()
542-
@test_throws ArgumentError df[:, Cols(x -> true, 1, operator=intersect)]
543-
@test_throws ArgumentError df[:, Cols(1, x -> true, operator=intersect)]
544+
@test df[:, Cols(x -> true, 1, operator=intersect)] == df[:, 1:1]
545+
@test df[:, Cols(1, x -> true, operator=intersect)] == df[:, 1:1]
546+
547+
@test df[:, Cols(startswith("a"), endswith("2"))] ==
548+
select(df, Cols(startswith("a"), endswith("2"))) ==
549+
df[:, ["a1", "a2", "b2"]]
550+
@test df[:, Cols(startswith("a"), endswith("2"), operator=intersect)] ==
551+
df[:, Cols(startswith("a"), :, endswith("2"), operator=intersect)] ==
552+
select(df, Cols(startswith("a"), endswith("2"), operator=intersect)) ==
553+
df[:, ["a2"]]
554+
@test df[:, Cols(startswith("a"), endswith("2"), operator=setdiff)] ==
555+
select(df, Cols(startswith("a"), endswith("2"), operator=setdiff)) ==
556+
df[:, ["a1"]]
557+
@test df[:, Cols(startswith("a"), endswith("2"), :, operator=setdiff)] == DataFrame()
544558
end
545559

546560
@testset "views" begin

0 commit comments

Comments
 (0)