Skip to content

Commit 0a82ccd

Browse files
authored
Merge pull request #992 from JuliaAI/selectcols
Allow `selectcols` to have tuples as "indices" argument
2 parents d65ed1f + d0b5240 commit 0a82ccd

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
with:
3535
version: ${{ matrix.version }}
3636
arch: ${{ matrix.arch }}
37-
- uses: actions/cache@v1
37+
- uses: julia-actions/cache@v1
3838
env:
3939
cache-name: cache-artifacts
4040
with:

src/interface/data_utils.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
101101
return Tables.getcolumn(cols, c)
102102
end
103103

104-
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray})
104+
function MMI.selectcols(::FI, ::Val{:table}, X, c)
105105
if isdataframe(X)
106106
return X[!, c]
107107
else
@@ -115,18 +115,28 @@ end
115115
# utils for `select`*
116116

117117
# project named tuple onto a tuple with only specified `labels` or indices:
118-
function project(t::NamedTuple, labels::AbstractArray{Symbol})
118+
function project(t::NamedTuple, labels::Union{AbstractArray{Symbol},NTuple{<:Any,Symbol}})
119119
return NamedTuple{tuple(labels...)}(t)
120120
end
121121

122122
project(t::NamedTuple, label::Colon) = t
123123
project(t::NamedTuple, label::Symbol) = project(t, [label,])
124124
project(t::NamedTuple, i::Integer) = project(t, [i,])
125125

126-
function project(t::NamedTuple, indices::AbstractArray{<:Integer})
126+
function project(
127+
t::NamedTuple,
128+
indices::AbstractArray{<:Integer},
129+
)
127130
return NamedTuple{tuple(keys(t)[indices]...)}(tuple([t[i] for i in indices]...))
128131
end
129132

133+
function project(
134+
t::NamedTuple,
135+
indices::Tuple{<:Any,Vararg{<:Integer}},
136+
)
137+
return NamedTuple{tuple(keys(t)[[indices...]]...)}(tuple([t[i] for i in indices]...))
138+
end
139+
130140
# utils for selectrows
131141
typename(X) = split(string(supertype(typeof(X))), '.')[end]
132142
isdataframe(X) = typename(X) == "AbstractDataFrame"

test/interface/data_utils.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ end
149149
s = schema(tt)
150150
@test nrows(tt) == N
151151

152-
@test selectcols(tt, 4:6) ==
152+
@test selectcols(tt, 4:6) == selectcols(tt, (4, 5, 6)) ==
153+
selectcols(tt, (:x4, :x5, :z)) ==
154+
selectcols(tt, [:x4, :x5, :z]) ==
153155
selectcols(TypedTables.Table(x4=tt.x4, x5=tt.x5, z=tt.z), :)
154156
@test selectcols(tt, [:x1, :z]) ==
155157
selectcols(TypedTables.Table(x1=tt.x1, z=tt.z), :)
@@ -197,6 +199,9 @@ end
197199
v = categorical(collect("asdfasdf"))
198200
tt = TypedTables.Table(v=v, w=v)
199201
@test selectcols(tt, :w) == v
202+
203+
X = (; x1=ones(3), x2=ones(3), x3=ones(3));
204+
selectcols(X, MLJBase.schema(X).names)
200205
end
201206

202207
# https://github.com/JuliaAI/MLJBase.jl/issues/784

test/resampling.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,8 @@ end
862862
measures=[LogLoss(), BrierScore()], verbosity=0)
863863
end
864864

865+
docstring_text = @doc(PerformanceEvaluation) |> string
866+
865867
@testset "reported fields in documentation" begin
866868
# Using `evaluate` to obtain a `PerformanceEvaluation` object.
867869
clf = ConstantClassifier()
@@ -875,11 +877,11 @@ end
875877
cols = ["measure", "operation", "measurement", "1.96*SE", "per_fold"]
876878
@test all(contains.(show_text, cols))
877879
print(show_text)
878-
docstring_text = string(@doc(PerformanceEvaluation))
879880
for fieldname in fieldnames(PerformanceEvaluation)
880881
@test contains(show_text, string(fieldname))
881882
# string(text::Markdown.MD) converts `-` list items to `*`.
882-
@test contains(docstring_text, " * `$fieldname`")
883+
@test contains(docstring_text, "* `$fieldname`") ||
884+
contains(docstring_text, "- `$fieldname`")
883885
end
884886

885887
measures = [LogLoss(), Accuracy()]

0 commit comments

Comments
 (0)