diff --git a/Project.toml b/Project.toml index d604fb6..9666236 100644 --- a/Project.toml +++ b/Project.toml @@ -1,34 +1,41 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.2.11" +version = "0.2.12" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] ChainRulesCore = "1.0" +DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" FoldsThreads = "0.1" +SimpleTraits = "0.9" ShowCases = "0.1" StatsBase = "0.33" +Tables = "1.10" Transducers = "0.4" julia = "1.6" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"] +test = ["ChainRulesTestUtils", "DataFrames", "SparseArrays", "Test", "Zygote"] diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 143ff09..40c6aa5 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -8,12 +8,19 @@ using FLoops.Transducers: Executor, ThreadedEx using FoldsThreads: TaskPoolEx import StatsBase: sample using Transducers +using Tables +using DataAPI using Base: @propagate_inbounds using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn! 
import ChainRulesCore: rrule using ChainRulesCore: @non_differentiable, unthunk, AbstractZero, NoTangent, ZeroTangent, ProjectTo +using SimpleTraits + +@traitdef IsTable{X} +@traitimpl IsTable{X} <- Tables.istable(X) + include("observation.jl") export numobs, diff --git a/src/observation.jl b/src/observation.jl index ad99bf4..4abffbd 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -3,30 +3,68 @@ Return the total number of observations contained in `data`. -If `data` does not have `numobs` defined, then this function -falls back to `length(data)`. +If `data` does not have `numobs` defined, +then in the case of `Tables.istable(data) == true` +returns the number of rows, otherwise returns `length(data)`. + Authors of custom data containers should implement `Base.length` for their type instead of `numobs`. `numobs` should only be implemented for types where there is a difference between `numobs` and `Base.length` (such as multi-dimensional arrays). -See also [`getobs`](@ref) +`numobs` supports by default nested combinations of arrays, tuples, +named tuples, and dictionaries. + +See also [`getobs`](@ref). + +# Examples +```jldoctest + +# named tuples +x = (a = [1, 2, 3], b = rand(6, 3)) +numobs(x) == 3 + +# dictionaries +x = Dict(:a => [1, 2, 3], :b => rand(6, 3)) +numobs(x) == 3 +``` +All internal containers must have the same number of observations: +```julia-repl +julia> x = (a = [1, 2, 3, 4], b = rand(6, 3)); + +julia> numobs(x) +ERROR: DimensionMismatch: All data containers must have the same number of observations. 
+Stacktrace: + [1] _check_numobs_error() + @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:163 + [2] _check_numobs + @ ~/.julia/dev/MLUtils/src/observation.jl:130 [inlined] + [3] numobs(data::NamedTuple{(:a, :b), Tuple{Vector{Int64}, Matrix{Float64}}}) + @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:177 + [4] top-level scope + @ REPL[35]:1 +``` """ function numobs end # Generic Fallbacks -numobs(data) = length(data) +@traitfn numobs(data::X) where {X; IsTable{X}} = DataAPI.nrow(data) +@traitfn numobs(data::X) where {X; !IsTable{X}} = length(data) + """ getobs(data, [idx]) -Return the observations corresponding to the observation-index `idx`. +Return the observations corresponding to the observation index `idx`. Note that `idx` can be any type as long as `data` has defined -`getobs` for that type. +`getobs` for that type. If `idx` is not provided, then materialize +all observations in `data`. + +If `data` does not have `getobs` defined, +then in the case of `Tables.istable(data) == true` +returns the row(s) in position `idx`, otherwise returns `data[idx]`. -If `data` does not have `getobs` defined, then this function -falls back to `data[idx]`. Authors of custom data containers should implement `Base.getindex` for their type instead of `getobs`. `getobs` should only be implemented for types where there is a @@ -40,13 +78,37 @@ Every author behind some custom data container can make this decision themselves. The output should be consistent when `idx` is a scalar vs vector. -See also [`getobs!`](@ref) and [`numobs`](@ref) +`getobs` supports by default nested combinations of arrays, tuples, +named tuples, and dictionaries. + +See also [`getobs!`](@ref) and [`numobs`](@ref). 
+ +# Examples + +```jldoctest +# named tuples +x = (a = [1, 2, 3], b = rand(6, 3)) + +getobs(x, 2) == (a = 2, b = x.b[:, 2]) +getobs(x, [1, 3]) == (a = [1, 3], b = x.b[:, [1, 3]]) + + +# dictionaries +x = Dict(:a => [1, 2, 3], :b => rand(6, 3)) + +getobs(x, 2) == Dict(:a => 2, :b => x[:b][:, 2]) +getobs(x, [1, 3]) == Dict(:a => [1, 3], :b => x[:b][:, [1, 3]]) +``` """ function getobs end # Generic Fallbacks + getobs(data) = data -getobs(data, idx) = data[idx] + +@traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false) +@traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx] + """ getobs!(buffer, data, idx) @@ -61,6 +123,8 @@ method is provided for the type of `data`, then `buffer` will be because the type of `data` may not lend itself to the concept of `copy!`. Thus, supporting a custom `getobs!` is optional and not required. + +See also [`getobs`](@ref) and [`numobs`](@ref). """ function getobs! end # getobs!(buffer, data) = getobs(data) @@ -161,3 +225,5 @@ function getobs!(buffers, data::Dict, i) return buffers end + + diff --git a/test/observation.jl b/test/observation.jl index a0369bc..7907f7f 100644 --- a/test/observation.jl +++ b/test/observation.jl @@ -190,4 +190,12 @@ end @test getobs!((nothing,xbuf),(Xs,X), 3:4) == (getobs(Xs,3:4),xbuf) @test xbuf == getobs(X,3:4) end + + @testset "tables" begin + df = DataFrame(a=[1,2,3], y=["a","b","c"]) + @test getobs(df) == df + @test getobs(df, 1) == df[1,:] + @test getobs(df, 2:3) == df[2:3,:] + @test numobs(df) == 3 + end end diff --git a/test/runtests.jl b/test/runtests.jl index 5eee979..a88d66f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using MLUtils using MLUtils.Datasets using MLUtils: RingBuffer, eachobsparallel +using MLUtils: flatten, stack, unstack # also exported by DataFrames.jl using SparseArrays using Random, Statistics using Test @@ -9,6 +10,7 @@ using FoldsThreads: TaskPoolEx using ChainRulesTestUtils: test_rrule using Zygote: 
ZygoteRuleConfig using ChainRulesCore: rrule_via_ad +using DataFrames showcompact(io, x) = show(IOContext(io, :compact => true), x)