Skip to content

Commit a85c098

Browse files
numobs and getobs support for Tables.jl's tables (#124)
* support tables * use SimpleTraits.jl * more tests * docstring * more docs
1 parent e247fb5 commit a85c098

File tree

5 files changed

+102
-12
lines changed

5 files changed

+102
-12
lines changed

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,41 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.2.11"
4+
version = "0.2.12"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
89
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
910
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
1011
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
14+
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1315
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1416
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
17+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1518
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
1619

1720
[compat]
1821
ChainRulesCore = "1.0"
22+
DataAPI = "1.0"
1923
DelimitedFiles = "1.0"
2024
FLoops = "0.2"
2125
FoldsThreads = "0.1"
26+
SimpleTraits = "0.9"
2227
ShowCases = "0.1"
2328
StatsBase = "0.33"
29+
Tables = "1.10"
2430
Transducers = "0.4"
2531
julia = "1.6"
2632

2733
[extras]
2834
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
2936
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3037
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3138
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3239

3340
[targets]
34-
test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"]
41+
test = ["ChainRulesTestUtils", "DataFrames", "SparseArrays", "Test", "Zygote"]

src/MLUtils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,19 @@ using FLoops.Transducers: Executor, ThreadedEx
88
using FoldsThreads: TaskPoolEx
99
import StatsBase: sample
1010
using Transducers
11+
using Tables
12+
using DataAPI
1113
using Base: @propagate_inbounds
1214
using Random: AbstractRNG, shuffle!, GLOBAL_RNG, rand!, randn!
1315
import ChainRulesCore: rrule
1416
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
1517
NoTangent, ZeroTangent, ProjectTo
1618

19+
using SimpleTraits
20+
21+
@traitdef IsTable{X}
22+
@traitimpl IsTable{X} <- Tables.istable(X)
23+
1724

1825
include("observation.jl")
1926
export numobs,

src/observation.jl

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,68 @@
33
44
Return the total number of observations contained in `data`.
55
6-
If `data` does not have `numobs` defined, then this function
7-
falls back to `length(data)`.
6+
If `data` does not have `numobs` defined,
7+
then in the case of `Tables.istable(data) == true`
8+
returns the number of rows, otherwise returns `length(data)`.
9+
810
Authors of custom data containers should implement
911
`Base.length` for their type instead of `numobs`.
1012
`numobs` should only be implemented for types where there is a
1113
difference between `numobs` and `Base.length`
1214
(such as multi-dimensional arrays).
1315
14-
See also [`getobs`](@ref)
16+
`numobs` supports by default nested combinations of arrays, tuples,
17+
named tuples, and dictionaries.
18+
19+
See also [`getobs`](@ref).
20+
21+
# Examples
22+
```julia
23+
24+
# named tuples
25+
x = (a = [1, 2, 3], b = rand(6, 3))
26+
numobs(x) == 3
27+
28+
# dictionaries
29+
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))
30+
numobs(x) == 3
31+
```
32+
All internal containers must have the same number of observations:
33+
```julia-repl
34+
julia> x = (a = [1, 2, 3, 4], b = rand(6, 3));
35+
36+
julia> numobs(x)
37+
ERROR: DimensionMismatch: All data containers must have the same number of observations.
38+
Stacktrace:
39+
[1] _check_numobs_error()
40+
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:163
41+
[2] _check_numobs
42+
@ ~/.julia/dev/MLUtils/src/observation.jl:130 [inlined]
43+
[3] numobs(data::NamedTuple{(:a, :b), Tuple{Vector{Int64}, Matrix{Float64}}})
44+
@ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:177
45+
[4] top-level scope
46+
@ REPL[35]:1
47+
```
1548
"""
1649
function numobs end
1750

1851
# Generic Fallbacks
19-
numobs(data) = length(data)
52+
@traitfn numobs(data::X) where {X; IsTable{X}} = DataAPI.nrow(data)
53+
@traitfn numobs(data::X) where {X; !IsTable{X}} = length(data)
54+
2055

2156
"""
2257
getobs(data, [idx])
2358
24-
Return the observations corresponding to the observation-index `idx`.
59+
Return the observations corresponding to the observation index `idx`.
2560
Note that `idx` can be any type as long as `data` has defined
26-
`getobs` for that type.
61+
`getobs` for that type. If `idx` is not provided, then materialize
62+
all observations in `data`.
63+
64+
If `data` does not have `getobs` defined,
65+
then in the case of `Tables.istable(data) == true`
66+
returns the row(s) in position `idx`, otherwise returns `data[idx]`.
2767
28-
If `data` does not have `getobs` defined, then this function
29-
falls back to `data[idx]`.
3068
Authors of custom data containers should implement
3169
`Base.getindex` for their type instead of `getobs`.
3270
`getobs` should only be implemented for types where there is a
@@ -40,13 +78,37 @@ Every author behind some custom data container can make this
4078
decision themselves.
4179
The output should be consistent when `idx` is a scalar vs vector.
4280
43-
See also [`getobs!`](@ref) and [`numobs`](@ref)
81+
`getobs` supports by default nested combinations of arrays, tuples,
82+
named tuples, and dictionaries.
83+
84+
See also [`getobs!`](@ref) and [`numobs`](@ref).
85+
86+
# Examples
87+
88+
```julia
89+
# named tuples
90+
x = (a = [1, 2, 3], b = rand(6, 3))
91+
92+
getobs(x, 2) == (a = 2, b = x.b[:, 2])
93+
getobs(x, [1, 3]) == (a = [1, 3], b = x.b[:, [1, 3]])
94+
95+
96+
# dictionaries
97+
x = Dict(:a => [1, 2, 3], :b => rand(6, 3))
98+
99+
getobs(x, 2) == Dict(:a => 2, :b => x[:b][:, 2])
100+
getobs(x, [1, 3]) == Dict(:a => [1, 3], :b => x[:b][:, [1, 3]])
101+
```
44102
"""
45103
function getobs end
46104

47105
# Generic Fallbacks
106+
48107
getobs(data) = data
49-
getobs(data, idx) = data[idx]
108+
109+
@traitfn getobs(data::X, idx) where {X; IsTable{X}} = Tables.subset(data, idx, viewhint=false)
110+
@traitfn getobs(data::X, idx) where {X; !IsTable{X}} = data[idx]
111+
50112

51113
"""
52114
getobs!(buffer, data, idx)
@@ -61,6 +123,8 @@ method is provided for the type of `data`, then `buffer` will be
61123
because the type of `data` may not lend itself to the concept
62124
of `copy!`. Thus, supporting a custom `getobs!` is optional
63125
and not required.
126+
127+
See also [`getobs`](@ref) and [`numobs`](@ref).
64128
"""
65129
function getobs! end
66130
# getobs!(buffer, data) = getobs(data)
@@ -161,3 +225,5 @@ function getobs!(buffers, data::Dict, i)
161225

162226
return buffers
163227
end
228+
229+

test/observation.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,12 @@ end
190190
@test getobs!((nothing,xbuf),(Xs,X), 3:4) == (getobs(Xs,3:4),xbuf)
191191
@test xbuf == getobs(X,3:4)
192192
end
193+
194+
@testset "tables" begin
195+
df = DataFrame(a=[1,2,3], y=["a","b","c"])
196+
@test getobs(df) == df
197+
@test getobs(df, 1) == df[1,:]
198+
@test getobs(df, 2:3) == df[2:3,:]
199+
@test numobs(df) == 3
200+
end
193201
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using MLUtils
22
using MLUtils.Datasets
33
using MLUtils: RingBuffer, eachobsparallel
4+
using MLUtils: flatten, stack, unstack # also exported by DataFrames.jl
45
using SparseArrays
56
using Random, Statistics
67
using Test
@@ -9,6 +10,7 @@ using FoldsThreads: TaskPoolEx
910
using ChainRulesTestUtils: test_rrule
1011
using Zygote: ZygoteRuleConfig
1112
using ChainRulesCore: rrule_via_ad
13+
using DataFrames
1214

1315
showcompact(io, x) = show(IOContext(io, :compact => true), x)
1416

0 commit comments

Comments
 (0)