Skip to content

Commit d8e8196

Browse files
authored
Add SciML style formatter to MLDatasets.jl (#205)
* Handle missing values
* SciML style formatter
* Formatted files
* Run frontmatter
1 parent 42e6f5e commit d8e8196

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+1555
-1427
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style = "sciml"

.github/workflows/FormatCheck.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: format-check
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- master
7+
push:
8+
branches:
9+
- master
10+
11+
jobs:
12+
build:
13+
runs-on: ${{ matrix.os }}
14+
strategy:
15+
fail-fast: false
16+
matrix:
17+
version:
18+
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
19+
os:
20+
- ubuntu-latest
21+
arch:
22+
- x64
23+
steps:
24+
- uses: actions/checkout@v2
25+
- uses: julia-actions/setup-julia@v1
26+
with:
27+
version: ${{ matrix.version }}
28+
arch: ${{ matrix.arch }}
29+
- name: Install JuliaFormatter and Format
30+
run: |
31+
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
32+
julia -e 'using JuliaFormatter; format(".", verbose=true)'
33+
- name: Format check
34+
run: |
35+
julia -e '
36+
out = Cmd(`git diff`) |> read |> String
37+
if out == ""
38+
exit(0)
39+
else
40+
@error "Some files have not been formatted !!!"
41+
write(stdout, out)
42+
exit(1)
43+
end'

docs/make.jl

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,24 @@ using Documenter, MLDatasets
66
# Build documentation.
77
# ====================
88

9-
makedocs(
10-
modules = [MLDatasets],
11-
doctest = true,
12-
clean = false,
13-
sitename = "MLDatasets.jl",
14-
format = Documenter.HTML(
15-
canonical = "https://juliadata.github.io/MLDatasets.jl/stable/",
16-
assets = ["assets/favicon.ico"],
17-
prettyurls = get(ENV, "CI", nothing) == "true",
18-
collapselevel=3,
19-
),
20-
21-
authors = "Hiroyuki Shindo, Christof Stocker, Carlo Lucibello",
22-
23-
pages = Any[
24-
"Home" => "index.md",
25-
"Datasets" => Any[
26-
"Graphs" => "datasets/graphs.md",
27-
"Meshes" => "datasets/meshes.md",
28-
"Miscellaneous" => "datasets/misc.md",
29-
"Text" => "datasets/text.md",
30-
"Vision" => "datasets/vision.md",
31-
],
32-
"Creating Datasets" => Any["containers/overview.md"], # still experimental
33-
"LICENSE.md",
34-
],
35-
strict = true,
36-
checkdocs = :exports
37-
)
9+
makedocs(modules = [MLDatasets],
10+
doctest = true,
11+
clean = false,
12+
sitename = "MLDatasets.jl",
13+
format = Documenter.HTML(canonical = "https://juliadata.github.io/MLDatasets.jl/stable/",
14+
assets = ["assets/favicon.ico"],
15+
prettyurls = get(ENV, "CI", nothing) == "true",
16+
collapselevel = 3),
17+
authors = "Hiroyuki Shindo, Christof Stocker, Carlo Lucibello",
18+
pages = Any["Home" => "index.md",
19+
"Datasets" => Any["Graphs" => "datasets/graphs.md",
20+
"Meshes" => "datasets/meshes.md",
21+
"Miscellaneous" => "datasets/misc.md",
22+
"Text" => "datasets/text.md",
23+
"Vision" => "datasets/vision.md"],
24+
"Creating Datasets" => Any["containers/overview.md"], # still experimental
25+
"LICENSE.md"],
26+
strict = true,
27+
checkdocs = :exports)
3828

3929
deploydocs(repo = "github.com/JuliaML/MLDatasets.jl.git")

src/MLDatasets.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ include("require.jl") # export @require
2121
# In the other case instead, use `require import SomePkg` to force
2222
# the user to import it manually.
2323

24-
@require import JSON3="0f8b85d8-7281-11e9-16c2-39a750bddbf1"
25-
@require import DataFrames="a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
26-
@require import ImageShow="4e3cecfd-b093-5904-9786-8bbb286a6a31"
27-
@require import Chemfiles="46823bd8-5fb3-5f92-9aa0-96921f3dd015"
24+
@require import JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
25+
@require import DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
26+
@require import ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
27+
@require import Chemfiles = "46823bd8-5fb3-5f92-9aa0-96921f3dd015"
2828

2929
# @lazy import NPZ # lazy imported by FileIO
30-
@lazy import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
31-
@lazy import MAT="23992714-dd62-5051-b70f-ba57cb901cac"
32-
@lazy import HDF5="f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
30+
@lazy import Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
31+
@lazy import MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
32+
@lazy import HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
3333
# @lazy import JLD2
3434

3535
export getobs, numobs # From MLUtils.jl
@@ -93,7 +93,6 @@ export Omniglot
9393
include("datasets/vision/svhn2.jl")
9494
export SVHN2
9595

96-
9796
## Text
9897

9998
include("datasets/text/ptblm.jl")

src/abstract_datasets.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ Implements the following functionality:
99
"""
1010
abstract type AbstractDataset <: AbstractDataContainer end
1111

12-
1312
MLUtils.getobs(d::AbstractDataset) = d[:]
1413
MLUtils.getobs(d::AbstractDataset, i) = d[i]
1514

16-
function Base.show(io::IO, d::D) where D <: AbstractDataset
15+
function Base.show(io::IO, d::D) where {D <: AbstractDataset}
1716
print(io, "$(D.name.name)()")
1817
end
1918

20-
function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
19+
function Base.show(io::IO, ::MIME"text/plain", d::D) where {D <: AbstractDataset}
2120
recur_io = IOContext(io, :compact => false)
2221

2322
print(io, "dataset $(D.name.name):") # if the type is parameterized don't print the parameters
@@ -38,7 +37,7 @@ function leftalign(s::AbstractString, n::Int)
3837
if m > n
3938
return s[1:n]
4039
else
41-
return s * repeat(" ", n-m)
40+
return s * repeat(" ", n - m)
4241
end
4342
end
4443

@@ -59,19 +58,23 @@ a `features` and a `targets` fields.
5958
"""
6059
abstract type SupervisedDataset <: AbstractDataset end
6160

62-
63-
Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
64-
numobs((d.features, d.targets))
65-
61+
function Base.length(d::SupervisedDataset)
62+
Tables.istable(d.features) ? numobs_table(d.features) :
63+
numobs((d.features, d.targets))
64+
end
6665

6766
# We return named tuples
68-
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ?
69-
(features = d.features, targets=d.targets) :
67+
function Base.getindex(d::SupervisedDataset, ::Colon)
68+
Tables.istable(d.features) ?
69+
(features = d.features, targets = d.targets) :
7070
getobs((; d.features, d.targets))
71+
end
7172

72-
Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
73-
(features = getobs_table(d.features, i), targets=getobs_table(d.targets, i)) :
73+
function Base.getindex(d::SupervisedDataset, i)
74+
Tables.istable(d.features) ?
75+
(features = getobs_table(d.features, i), targets = getobs_table(d.targets, i)) :
7476
getobs((; d.features, d.targets), i)
77+
end
7578

7679
"""
7780
UnsupervisedDataset <: AbstractDataset
@@ -81,13 +84,11 @@ Concrete dataset types inheriting from it must provide a `features` field.
8184
"""
8285
abstract type UnsupervisedDataset <: AbstractDataset end
8386

84-
8587
Base.length(d::UnsupervisedDataset) = numobs(d.features)
8688

8789
Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features)
8890
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i)
8991

90-
9192
### DOCSTRING TEMPLATES ######################
9293

9394
# SUPERVISED TABLE
@@ -110,7 +111,6 @@ const METHODS_SUPERVISED_TABLE = """
110111
- `length(dataset)`: Number of observations.
111112
"""
112113

113-
114114
# SUPERVISED ARRAY DATASET
115115

116116
const ARGUMENTS_SUPERVISED_ARRAY = """

src/containers/cacheddataset.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ end
2828

2929
CachedDataset(source, cachesize::Int) = CachedDataset(source, 1:cachesize)
3030

31-
CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
31+
function CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source))
3232
CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
33+
end
3334

3435
function Base.getindex(dataset::CachedDataset, i::Integer)
3536
_i = findfirst(==(i), dataset.cacheidx)

src/containers/filedataset.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@ Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`)
1717
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
1818
(recursively globbing by `depth`). The glob order determines the traversal order.
1919
"""
20-
struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer
20+
struct FileDataset{F, T <: AbstractString} <: AbstractDataContainer
2121
loadfn::F
2222
paths::Vector{T}
2323
end
2424

2525
FileDataset(paths) = FileDataset(FileIO.load, paths)
26-
FileDataset(loadfn,
27-
dir::AbstractString,
28-
pattern::AbstractString = "*",
29-
depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
30-
FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) =
26+
function FileDataset(loadfn,
27+
dir::AbstractString,
28+
pattern::AbstractString = "*",
29+
depth = 4)
30+
FileDataset(loadfn, rglob(pattern, string(dir), depth))
31+
end
32+
function FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4)
3133
FileDataset(FileIO.load, dir, pattern, depth)
34+
end
3235

3336
Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i])
34-
Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
37+
function Base.getindex(dataset::FileDataset, is::AbstractVector)
38+
map(Base.Fix1(getobs, dataset), is)
39+
end
3540
Base.length(dataset::FileDataset) = length(dataset.paths)

src/containers/hdf5dataset.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
1919
For array datasets, the last dimension is assumed to be the observation dimension.
2020
For scalar datasets, the stored value is returned by `getobs` for any index.
2121
"""
22-
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
22+
struct HDF5Dataset{T <: Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
2323
fid::HDF5.File
2424
paths::T
2525
shapes::Vector{Tuple}
2626

27-
function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
27+
function HDF5Dataset(fid::HDF5.File, paths::T,
28+
shapes::Vector) where {
29+
T <:
30+
Union{HDF5.Dataset, Vector{HDF5.Dataset}}}
2831
_check_hdf5_shapes(shapes) ||
2932
throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))
3033

@@ -33,11 +36,13 @@ struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractData
3336
end
3437

3538
HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
36-
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
39+
function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset})
3740
HDF5Dataset(fid, paths, map(size, paths))
41+
end
3842
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
39-
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
43+
function HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString})
4044
HDF5Dataset(fid, map(p -> fid[p], paths))
45+
end
4146
HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths)
4247

4348
_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
@@ -46,10 +51,12 @@ function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
4651

4752
return dataset[I..., i]
4853
end
49-
Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) =
54+
function Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i)
5055
_getobs_hdf5(dataset.paths, only(dataset.shapes), i)
51-
Base.getindex(dataset::HDF5Dataset{<:Vector}, i) =
56+
end
57+
function Base.getindex(dataset::HDF5Dataset{<:Vector}, i)
5258
Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
59+
end
5360
Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))
5461

5562
"""

src/containers/jld2dataset.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
1010
each dataset in `paths`.
1111
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
1212
"""
13-
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
13+
struct JLD2Dataset{T <: JLD2.JLDFile, S <: Tuple} <: AbstractDataContainer
1414
fid::T
1515
paths::S
1616

@@ -27,7 +27,9 @@ end
2727
JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
2828
JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths)
2929

30-
Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
30+
function Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i)
31+
getobs(only(dataset.paths), i)
32+
end
3133
Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
3234
Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1])
3335

src/containers/tabledataset.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ end
2121
TableDataset(table::T) where {T} = TableDataset{T}(table)
2222
TableDataset(path::AbstractString) = TableDataset(read_csv(path))
2323

24-
2524
# slow accesses based on Tables.jl
2625
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
2726
function _getobs_column(x, i)
@@ -55,7 +54,6 @@ end
5554
Base.getindex(dataset::TableDataset, i) = getobs_table(dataset.table, i)
5655
Base.length(dataset::TableDataset) = numobs_table(dataset.table)
5756

58-
5957
# fast access for DataFrame
6058
# Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
6159
# Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)

0 commit comments

Comments
 (0)