Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Tables = "0.2, 1.0"
julia = "1.6"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -59,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
20 changes: 14 additions & 6 deletions src/interface/data_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,22 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r)
end

# Select a single column `c` (given by name or by integer position) from the
# table `X` and return it.
#
# Data frames are recognised by `isdataframe` and indexed directly with
# `X[!, c]`, which skips the conversion to a column table. Every other table is
# first converted to a named tuple of vectors via `Tables.columntable`.
#
# The previous body returned `cols[c]` unconditionally before the
# `isdataframe` check, which made the data frame branch unreachable.
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
    isdataframe(X) && return X[!, c]
    cols = Tables.columntable(X) # named tuple of vectors
    return cols[c]
end

# Select several columns `c` (an array of names or integer positions) from
# the table `X` and return them as a table.
#
# Data frames are recognised by `isdataframe` and indexed directly with
# `X[!, c]`. Every other table is converted to a named tuple of vectors,
# narrowed to the requested columns with `project`, and rebuilt with the
# materializer that `Tables.materializer(X)` reports for `X`.
#
# The previous body returned the materialized projection unconditionally
# before the `isdataframe` check, which made the data frame branch unreachable.
function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray)
    isdataframe(X) && return X[!, c]
    cols = Tables.columntable(X) # named tuple of vectors
    newcols = project(cols, c)
    return Tables.materializer(X)(newcols)
end

# -------------------------------
Expand All @@ -124,7 +132,7 @@ function project(t::NamedTuple, indices::AbstractArray{<:Integer})
end

# utils for selectrows
# Return, as a `String`, the bare name of the immediate supertype of
# `typeof(X)`: no module prefix and no type parameters.
#
# `nameof` is used instead of splitting the printed type on '.', because the
# printed form can contain module-qualified type parameters (for example
# `AbstractDict{Symbol, Base.RefValue{Int64}}`), and splitting that on '.'
# yields a fragment such as "RefValue{Int64}}" rather than the type name.
typename(X) = string(nameof(supertype(typeof(X))))

# Return `true` when the immediate supertype of `typeof(X)` is named
# "AbstractDataFrame". The check is by name only, so no data frame package
# has to be loaded here.
isdataframe(X) = typename(X) == "AbstractDataFrame"

# ----------------------------------------------------------------
Expand Down
15 changes: 13 additions & 2 deletions test/interface/data_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import DataFrames

rng = StableRNGs.StableRNG(123)

@testset "categorical" begin
Expand All @@ -23,7 +25,7 @@ end
b = categorical(["a", "b", "c"])
c = categorical(["a", "b", "c"]; ordered=true)
X = (x1=x, x2=z, x3=b, x4=c)
@test MLJModelInterface.scitype(x) == ST.scitype(x)
@test MLJModelInterface.scitype(x) == ST.scitype(x)
@test MLJModelInterface.scitype(y) == ST.scitype(y)
@test MLJModelInterface.scitype(z) == ST.scitype(z)
@test MLJModelInterface.scitype(a) == ST.scitype(a)
Expand All @@ -39,7 +41,7 @@ end
b = categorical(["a", "b", "c"])
c = categorical(["a", "b", "c"]; ordered=true)
X = (x1=x, x2=z, x3=b, x4=c)
@test_throws ArgumentError MLJModelInterface.schema(x)
@test_throws ArgumentError MLJModelInterface.schema(x)
@test MLJModelInterface.schema(X) == ST.schema(X)
end

Expand Down Expand Up @@ -197,4 +199,13 @@ end
@test selectcols(tt, :w) == v
end

# https://github.com/JuliaAI/MLJBase.jl/issues/784
# Checks that a `DataFrame` is recognised by `typename`/`isdataframe`, and that
# row and column selection on it agree with direct data frame indexing.
@testset "typename and dataframes" begin
    df = DataFrames.DataFrame(x=[1,2,3], y=[2,3,4], z=[4,5,6])
    @test MLJBase.typename(df) == "AbstractDataFrame"
    @test MLJBase.isdataframe(df)
    @test selectrows(df, 2:3) == df[2:3, :]
    # single-column selection goes through a separate `selectcols` method
    @test selectcols(df, :x) == df[!, :x]
    @test selectcols(df, [:x, :z]) == df[!, [:x, :z]]
end

true