Skip to content

Commit 08e3122

Browse files
authored
fixes to select following anthony's comments (#6)
* fixes to select following anthony's comments * patch release
1 parent 3736e2a commit 08e3122

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"

src/data_utils.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,16 @@ depending on the value type. See also: [`selectrows`](@ref),
248248
"""
249249
select(X, r, c) = select(get_interface_mode(), vtrait(X), X, r, c)
250250

251-
select(::Mode, ::Val, X, r, c) = selectcols(selectrows(X, r), c)
251+
# only used here to denote "group of indices"
252+
const MIdx = Union{AbstractArray,Colon}
253+
254+
select(::Mode, ::Val, X, r::MIdx, c) = selectcols(selectrows(X, r), c)
255+
select(::Mode, ::Val, X, r, c::MIdx) = selectcols(selectrows(X, r), c)
256+
select(::Mode, ::Val, X, r::MIdx, c::MIdx) = selectcols(selectrows(X, r), c)
257+
select(::Mode, ::Val, X, r, c) = _squeeze(selectcols(selectrows(X, r), c))
258+
259+
_squeeze(::Nothing) = nothing
260+
_squeeze(v) = first(v)
252261

253262
# ------------------------------------------------------------------------
254263
# UnivariateFinite

test/data_utils.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ end
9898
X = nothing
9999
@test selectrows(X, 1) === nothing
100100
@test selectcols(X, 1) === nothing
101-
@test select(X, 1, 2) === nothing
101+
@test select(X, 1, 2) === nothing
102102

103103
# vector
104104
X = ones(5)
@@ -116,7 +116,7 @@ end
116116
@test selectcols(X, 1) == ones(5,)
117117
@test selectcols(X, 1:2) == ones(5, 2)
118118
@test selectcols(X, :) === X
119-
@test select(X, 1, 1) == [1.0]
119+
@test select(X, 1, 1) == 1.0
120120
@test select(X, 1:2, 1) == ones(2,)
121121
@test select(X, 1:2, 1:2) == ones(2, 2)
122122

@@ -131,7 +131,6 @@ end
131131
@test_throws ArgumentError selectcols(X, 1)
132132
@test_throws ArgumentError select(X, 1, 1)
133133
end
134-
# ------------------------------------------------------------------------
135134
@testset "select-full" begin
136135
setfull()
137136
M.selectrows(::FI, ::Val{:table}, X, ::Colon) = X
@@ -167,13 +166,31 @@ end
167166
@test selectcols(X, 1) == [1,2,3]
168167
@test selectcols(X, 1:2) == (x = [1, 2, 3], y = [4, 5, 6])
169168
@test selectcols(X, :) === X
170-
@test select(X, 1, 1) == [1]
169+
@test select(X, 1, 1) == 1
171170
@test select(X, 1:2, 1) == [1,2]
172171
@test select(X, :, 1) == [1,2,3]
173172
@test selectcols(X, :x) == [1,2,3]
174173
@test select(X, 1:2, :z) == [0,0]
175-
end
174+
#
175+
# extra tests by Anthony
176+
X = (x=[1,2,3], y=[10, 20, 30], z= [:a, :b, :c])
177+
@test select(X, 2, :y) == 20
178+
@test select(X, 2, [:x, :y]) == (x=[2,], y=[20,])
179+
@test select(X, 2:3, :x) == [2, 3]
180+
@test select(X, 2:3, [:x, :y]) == (x=[2, 3], y=[20, 30])
181+
@test select(X, :, [:x, :y]) == select(X, 1:3, [:x, :y])
182+
@test select(X, 2, :) == select(X, 2, 1:3)
183+
@test select(X, 2:3, :) == select(X, 2:3, 1:3)
176184

185+
@test select(X, 2, 2) == 20
186+
@test select(X, 2, [1, 2]) == (x=[2,], y=[20,])
187+
@test select(X, 2:3, 1) == [2, 3]
188+
@test select(X, 2:3, [1, 2]) == (x=[2, 3], y=[20, 30])
189+
@test select(X, :, [1, 2]) == select(X, 1:3, [1, 2])
190+
@test select(X, 2, :) == select(X, 2, 1:3)
191+
@test select(X, 2:3, :) == select(X, 2:3, 1:3)
192+
end
193+
# ------------------------------------------------------------------------
177194
@testset "univ-finite" begin
178195
setlight()
179196
@test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))

0 commit comments

Comments
 (0)