From de5bb39d86a22644f3a6cb20a54145225cb570a2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 11:17:09 +0200 Subject: [PATCH 1/5] support tables --- Project.toml | 9 +++++++-- src/MLUtils.jl | 1 + src/observation.jl | 4 ++-- test/observation.jl | 7 +++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d604fb6..6f146be 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.2.11" +version = "0.2.12" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443" @@ -12,23 +13,27 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] ChainRulesCore = "1.0" +DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" FoldsThreads = "0.1" ShowCases = "0.1" StatsBase = "0.33" +Tables = "1.10" Transducers = "0.4" julia = "1.6" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"] +test = ["ChainRulesTestUtils", "DataFrames", "SparseArrays", "Test", "Zygote"] diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 143ff09..bc10837 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -8,6 +8,7 @@ using FLoops.Transducers: Executor, ThreadedEx using FoldsThreads: TaskPoolEx import StatsBase: sample using Transducers +using Tables using Base: @propagate_inbounds using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn! import ChainRulesCore: rrule diff --git a/src/observation.jl b/src/observation.jl index ad99bf4..1c48f7f 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -16,7 +16,7 @@ See also [`getobs`](@ref) function numobs end # Generic Fallbacks -numobs(data) = length(data) +numobs(data) = Tables.istable(data) ? DataAPI.nrow(x) : length(data) """ getobs(data, [idx]) @@ -46,7 +46,7 @@ function getobs end # Generic Fallbacks getobs(data) = data -getobs(data, idx) = data[idx] +getobs(data, idx) = Tables.istable(data) ? Tables.subset(data, idx) : data[idx] """ getobs!(buffer, data, idx) diff --git a/test/observation.jl b/test/observation.jl index a0369bc..7a4a7e5 100644 --- a/test/observation.jl +++ b/test/observation.jl @@ -190,4 +190,11 @@ end @test getobs!((nothing,xbuf),(Xs,X), 3:4) == (getobs(Xs,3:4),xbuf) @test xbuf == getobs(X,3:4) end + + @testset "tables" begin + df = DataFrame(a=[1,2,3], y=["a","b","c"]) + @test getobs(df, 1) == df[:, 1] + @test getobs(df, 2:3) == df[:, 2:3] + @test numobs(df) == 3 + end end From fc704175843b994bc55fea6ed11afebd9890f846 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 16 Oct 2022 12:12:26 +0200 Subject: [PATCH 2/5] use SimpleTraits.jl --- Project.toml | 2 ++ src/MLUtils.jl | 6 ++++++ src/observation.jl | 12 ++++++++++-- test/observation.jl | 4 ++-- test/runtests.jl | 2 ++ 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 6f146be..9666236 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -22,6 +23,7 @@ DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" FoldsThreads = "0.1" +SimpleTraits = "0.9" ShowCases = "0.1" StatsBase = "0.33" Tables = "1.10" diff --git a/src/MLUtils.jl b/src/MLUtils.jl index bc10837..40c6aa5 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -9,12 +9,18 @@ using FoldsThreads: TaskPoolEx import StatsBase: sample using Transducers using Tables +using DataAPI using Base: @propagate_inbounds using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn! import ChainRulesCore: rrule using ChainRulesCore: @non_differentiable, unthunk, AbstractZero, NoTangent, ZeroTangent, ProjectTo +using SimpleTraits + +@traitdef IsTable{X} +@traitimpl IsTable{X} <- Tables.istable(X) + include("observation.jl") export numobs, diff --git a/src/observation.jl b/src/observation.jl index 1c48f7f..3c500b5 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -16,7 +16,9 @@ See also [`getobs`](@ref) function numobs end # Generic Fallbacks -numobs(data) = Tables.istable(data) ? DataAPI.nrow(x) : length(data) +@traitfn numobs(data::X) where {X; IsTable{X}} = DataAPI.nrow(data) +@traitfn numobs(data::X) where {X; !IsTable{X}} = length(data) + """ getobs(data, [idx]) @@ -46,7 +48,11 @@ function getobs end # Generic Fallbacks getobs(data) = data -getobs(data, idx) = Tables.istable(data) ? Tables.subset(data, idx) : data[idx] +# getobs(data, idx) = data[idx] + +@traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false) +@traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx] + """ getobs!(buffer, data, idx) @@ -161,3 +167,5 @@ function getobs!(buffers, data::Dict, i) return buffers end + + diff --git a/test/observation.jl b/test/observation.jl index 7a4a7e5..1e3d098 100644 --- a/test/observation.jl +++ b/test/observation.jl @@ -193,8 +193,8 @@ end @testset "tables" begin df = DataFrame(a=[1,2,3], y=["a","b","c"]) - @test getobs(df, 1) == df[:, 1] - @test getobs(df, 2:3) == df[:, 2:3] + @test getobs(df, 1) == df[1,:] + @test getobs(df, 2:3) == df[2:3,:] @test numobs(df) == 3 end end diff --git a/test/runtests.jl b/test/runtests.jl index 5eee979..a88d66f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using MLUtils using MLUtils.Datasets using MLUtils: RingBuffer, eachobsparallel +using MLUtils: flatten, stack, unstack # also exported by DataFrames.jl using SparseArrays using Random, Statistics using Test @@ -9,6 +10,7 @@ using FoldsThreads: TaskPoolEx using ChainRulesTestUtils: test_rrule using Zygote: ZygoteRuleConfig using ChainRulesCore: rrule_via_ad +using DataFrames showcompact(io, x) = show(IOContext(io, :compact => true), x) From 310e8b98b63962c53f63c8a4220b63e4baa4783c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 17 Oct 2022 06:04:19 +0200 Subject: [PATCH 3/5] more tests --- src/observation.jl | 2 +- test/observation.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/observation.jl b/src/observation.jl index 3c500b5..1a39ded 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -47,8 +47,8 @@ See also [`getobs!`](@ref) and [`numobs`](@ref) function getobs end # Generic Fallbacks + getobs(data) = data -# getobs(data, idx) = data[idx] @traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false) @traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx] diff --git a/test/observation.jl b/test/observation.jl index 1e3d098..7907f7f 100644 --- a/test/observation.jl +++ b/test/observation.jl @@ -193,6 +193,7 @@ end @testset "tables" begin df = DataFrame(a=[1,2,3], y=["a","b","c"]) + @test getobs(df) == df @test getobs(df, 1) == df[1,:] @test getobs(df, 2:3) == df[2:3,:] @test numobs(df) == 3 From dd7e856e064614b0ab4c457bb80b4cf3cab122ad Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 23 Oct 2022 10:50:05 +0200 Subject: [PATCH 4/5] docstring --- src/observation.jl | 44 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/observation.jl b/src/observation.jl index 1a39ded..677bb2e 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -3,14 +3,19 @@ Return the total number of observations contained in `data`. -If `data` does not have `numobs` defined, then this function -falls back to `length(data)`. +If `data` does not have `numobs` defined, +then in the case of `Tables.table(data) == true` +returns the number of rows, otherwise returns `length(data)`. + Authors of custom data containers should implement `Base.length` for their type instead of `numobs`. `numobs` should only be implemented for types where there is a difference between `numobs` and `Base.length` (such as multi-dimensional arrays). +`getobs` supports by default nested combinations of array, tuple, +named tuples, and dictionaries. + See also [`getobs`](@ref) """ function numobs end @@ -23,12 +28,15 @@ function numobs end """ getobs(data, [idx]) -Return the observations corresponding to the observation-index `idx`. +Return the observations corresponding to the observation index `idx`. Note that `idx` can be any type as long as `data` has defined -`getobs` for that type. +`getobs` for that type. If `idx` is not provided, then materialize +all observations in `data`. + +If `data` does not have `getobs` defined, +then in the case of `Tables.table(data) == true` +returns the row(s) in position `idx`, otherwise returns `data[idx]`. -If `data` does not have `getobs` defined, then this function -falls back to `data[idx]`. Authors of custom data containers should implement `Base.getindex` for their type instead of `getobs`. `getobs` should only be implemented for types where there is a @@ -42,7 +50,27 @@ Every author behind some custom data container can make this decision themselves. The output should be consistent when `idx` is a scalar vs vector. -See also [`getobs!`](@ref) and [`numobs`](@ref) +`getobs` supports by default nested combinations of array, tuple, +named tuples, and dictionaries. + +See also [`getobs!`](@ref) and [`numobs`](@ref). + +# Examples + +```jldoctest +# named tuples +x = (a = [1, 2, 3], b = rand(6, 3)) + +getobs(x, 2) == (a = 2, b = x.b[:, 2]) +getobs(x, [1, 3]) == (a = [1, 3], b = x.b[:, [1, 3]]) + + +# dictionaries +x = Dict(:a => [1, 2, 3], :b => rand(6, 3)) + +getobs(x, 2) == Dict(:a => 2, :b => x[:b][:, 2]) +getobs(x, [1, 3]) == Dict(:a => [1, 3], :b => x[:b][:, [1, 3]]) +``` """ function getobs end @@ -67,6 +95,8 @@ method is provided for the type of `data`, then `buffer` will be because the type of `data` may not lend itself to the concept of `copy!`. Thus, supporting a custom `getobs!` is optional and not required. + +See also [`getobs`](@ref) and [`numobs`](@ref). """ function getobs! end # getobs!(buffer, data) = getobs(data) From 429062baa7a5c6b4ae687d76ee9d9e8e9d2780f7 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 23 Oct 2022 12:11:02 +0200 Subject: [PATCH 5/5] more docs --- src/observation.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/observation.jl b/src/observation.jl index 677bb2e..4abffbd 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -16,7 +16,35 @@ difference between `numobs` and `Base.length` `getobs` supports by default nested combinations of array, tuple, named tuples, and dictionaries. -See also [`getobs`](@ref) +See also [`getobs`](@ref). + +# Examples +```jldoctest + +# named tuples +x = (a = [1, 2, 3], b = rand(6, 3)) +numobs(x) == 3 + +# dictionaries +x = Dict(:a => [1, 2, 3], :b => rand(6, 3)) +numobs(x) == 3 +``` +All internal containers must have the same number of observations: +```juliarepl +julia> x = (a = [1, 2, 3, 4], b = rand(6, 3)); + +julia> numobs(x) +ERROR: DimensionMismatch: All data containers must have the same number of observations. +Stacktrace: + [1] _check_numobs_error() + @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:163 + [2] _check_numobs + @ ~/.julia/dev/MLUtils/src/observation.jl:130 [inlined] + [3] numobs(data::NamedTuple{(:a, :b), Tuple{Vector{Int64}, Matrix{Float64}}}) + @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:177 + [4] top-level scope + @ REPL[35]:1 +``` """ function numobs end