Skip to content

Commit 7126247

Browse files
committed
adding initial logic
1 parent 64f0a79 commit 7126247

File tree

4 files changed

+80
-214
lines changed

4 files changed

+80
-214
lines changed

Manifest.toml

Lines changed: 1 addition & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -1,207 +1,6 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
[[Base64]]
4-
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
5-
6-
[[CategoricalArrays]]
7-
deps = ["Compat", "DataAPI", "Future", "JSON", "Missings", "Printf", "Reexport", "Statistics", "Unicode"]
8-
git-tree-sha1 = "23d7324164c89638c18f6d7f90d972fa9c4fa9fb"
9-
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
10-
version = "0.7.7"
11-
12-
[[ColorTypes]]
13-
deps = ["FixedPointNumbers", "Random"]
14-
git-tree-sha1 = "b9de8dc6106e09c79f3f776c27c62360d30e5eb8"
15-
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
16-
version = "0.9.1"
17-
18-
[[Compat]]
19-
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20-
git-tree-sha1 = "3819f476b6b37ef8ea837070ed831b4ebadfa1e9"
21-
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22-
version = "3.2.0"
23-
24-
[[Crayons]]
25-
git-tree-sha1 = "cb7a62895da739fe5bb43f1a26d4292baf4b3dc0"
26-
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
27-
version = "4.0.1"
28-
29-
[[DataAPI]]
30-
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
31-
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
32-
version = "1.1.0"
33-
34-
[[DataValueInterfaces]]
35-
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
36-
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
37-
version = "1.0.0"
38-
39-
[[Dates]]
40-
deps = ["Printf"]
41-
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
42-
43-
[[DelimitedFiles]]
44-
deps = ["Mmap"]
45-
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
46-
47-
[[Distributed]]
48-
deps = ["Random", "Serialization", "Sockets"]
49-
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
50-
51-
[[FixedPointNumbers]]
52-
git-tree-sha1 = "4aaea64dd0c30ad79037084f8ca2b94348e65eaa"
53-
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
54-
version = "0.7.1"
55-
56-
[[Formatting]]
57-
deps = ["Printf"]
58-
git-tree-sha1 = "a0c901c29c0e7c763342751c0a94211d56c0de5c"
59-
uuid = "59287772-0a20-5a39-b81b-1366585eb4c0"
60-
version = "0.4.1"
61-
62-
[[Future]]
63-
deps = ["Random"]
64-
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
65-
66-
[[InteractiveUtils]]
67-
deps = ["Markdown"]
68-
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
69-
70-
[[IteratorInterfaceExtensions]]
71-
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
72-
uuid = "82899510-4779-5014-852e-03e436cf321d"
73-
version = "1.0.0"
74-
75-
[[JSON]]
76-
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
77-
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
78-
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
79-
version = "0.21.0"
80-
81-
[[LibGit2]]
82-
deps = ["Printf"]
83-
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
84-
85-
[[Libdl]]
86-
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
87-
88-
[[LinearAlgebra]]
89-
deps = ["Libdl"]
90-
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
91-
92-
[[Logging]]
93-
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
94-
95-
[[MLJScientificTypes]]
96-
deps = ["CategoricalArrays", "ColorTypes", "PrettyTables", "ScientificTypes", "Tables"]
97-
path = "/Users/tlienart/.julia/dev/MLJScientificTypes.jl"
98-
uuid = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
99-
version = "0.1.0"
100-
101-
[[Markdown]]
102-
deps = ["Base64"]
103-
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
104-
105-
[[Missings]]
106-
deps = ["DataAPI"]
107-
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
108-
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
109-
version = "0.4.3"
110-
111-
[[Mmap]]
112-
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
113-
114-
[[OrderedCollections]]
115-
deps = ["Random", "Serialization", "Test"]
116-
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
117-
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
118-
version = "1.1.0"
119-
120-
[[Parameters]]
121-
deps = ["OrderedCollections"]
122-
git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38"
123-
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
124-
version = "0.12.0"
125-
126-
[[Parsers]]
127-
deps = ["Dates", "Test"]
128-
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
129-
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
130-
version = "0.3.10"
131-
132-
[[Pkg]]
133-
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
134-
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
135-
136-
[[PrettyTables]]
137-
deps = ["Crayons", "Formatting", "Parameters", "Reexport", "Tables"]
138-
git-tree-sha1 = "a98edb4f57f236e649599efa68b5e78c43cb51e1"
139-
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
140-
version = "0.8.0"
141-
142-
[[Printf]]
143-
deps = ["Unicode"]
144-
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
145-
146-
[[REPL]]
147-
deps = ["InteractiveUtils", "Markdown", "Sockets"]
148-
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
149-
150-
[[Random]]
151-
deps = ["Serialization"]
152-
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
153-
154-
[[Reexport]]
155-
deps = ["Pkg"]
156-
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
157-
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
158-
version = "0.2.0"
159-
160-
[[SHA]]
161-
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
162-
1633
[[ScientificTypes]]
164-
path = "../ScientificTypes"
4+
path = "/Users/tlienart/.julia/dev/ScientificTypes"
1655
uuid = "321657f4-b219-11e9-178b-2701a2544e81"
1666
version = "0.6.0"
167-
168-
[[Serialization]]
169-
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
170-
171-
[[SharedArrays]]
172-
deps = ["Distributed", "Mmap", "Random", "Serialization"]
173-
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
174-
175-
[[Sockets]]
176-
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
177-
178-
[[SparseArrays]]
179-
deps = ["LinearAlgebra", "Random"]
180-
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
181-
182-
[[Statistics]]
183-
deps = ["LinearAlgebra", "SparseArrays"]
184-
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
185-
186-
[[TableTraits]]
187-
deps = ["IteratorInterfaceExtensions"]
188-
git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e"
189-
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
190-
version = "1.0.0"
191-
192-
[[Tables]]
193-
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
194-
git-tree-sha1 = "aaed7b3b00248ff6a794375ad6adf30f30ca5591"
195-
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
196-
version = "0.2.11"
197-
198-
[[Test]]
199-
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
200-
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
201-
202-
[[UUIDs]]
203-
deps = ["Random", "SHA"]
204-
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
205-
206-
[[Unicode]]
207-
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@ authors = ["Thibaut Lienart <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7-
MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
7+
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
88

99
[compat]
10-
MLJScientificTypes = "^0.1"
10+
ScientificTypes = "^0.6"
11+
12+
[extras]
13+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
16+
[targets]
17+
test = ["Test", "Tables"]

src/MLJModelInterface.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,46 @@ module MLJModelInterface
22

33
# ------------------------------------------------------------------------
44
# Dependency (note that ScientificTypes itself does not have dependencies)
5-
using ScientificTypes
5+
import ScientificTypes: trait
66

77
# ------------------------------------------------------------------------
8-
# Exports
9-
export Dummy, Live, get_interface_mode
8+
# Single export: matrix, everything else is qualified in MLJBase
109
export matrix
1110

1211
# ------------------------------------------------------------------------
13-
# Mode trick
1412

1513
abstract type Mode end
16-
struct Dummy <: Mode end
17-
struct Live <: Mode end
14+
struct LightInterface <: Mode end
15+
struct FullInterface <: Mode end
1816

19-
const INTERFACE_MODE = Ref{Mode}(Dummy())
17+
const INTERFACE_MODE = Ref{Mode}(LightInterface())
18+
19+
set_interface_mode(m::Mode) = (INTERFACE_MODE[] = m)
2020

2121
get_interface_mode() = INTERFACE_MODE[]
2222

23-
matrix(a...; kw...) = matrix(a...; interface_mode=get_interface_mode(), kw...)
23+
struct InterfaceError <: Exception
24+
m::String
25+
end
26+
27+
vtrait(X) = X |> trait |> Val
28+
29+
"""
30+
matrix(X; transpose=false)
31+
32+
If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
33+
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`.
34+
"""
35+
matrix(X; kw...) = matrix(vtrait(X), X, get_interface_mode(); kw...)
36+
37+
matrix(::Val{:other}, X::AbstractMatrix, ::Mode; transpose=false) =
38+
transpose ? permutedims(X) : X
39+
40+
matrix(::Val{:other}, X, ::Mode; kw...) =
41+
throw(ArgumentError("Function `matrix` only supports AbstractMatrix or " *
42+
"containers implementing the Tables interface."))
2443

25-
matrix(a...; interface_mode::Mode=Dummy(), kw...) =
26-
error("Only `MLJModelInterface` loaded. Do `import MLJBase`.")
44+
matrix(::Val{:table}, X, ::LightInterface; kw...) =
45+
throw(InterfaceError("Only `MLJModelInterface` loaded. Import `MLJBase`."))
2746

2847
end # module

test/runtests.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using Test, MLJModelInterface, ScientificTypes
2+
using Tables
3+
4+
const M = MLJModelInterface
5+
6+
ScientificTypes.TRAIT_FUNCTION_GIVEN_NAME[:table] = Tables.istable
7+
8+
@testset "light-interface" begin
9+
M.set_interface_mode(M.LightInterface())
10+
@test M.get_interface_mode() isa M.LightInterface
11+
12+
# matrix object (:other)
13+
X = zeros(3, 4)
14+
mX = matrix(X)
15+
mtX = matrix(X; transpose=true)
16+
17+
@test mX === X
18+
@test mtX == permutedims(X)
19+
20+
# :other but not matrix
21+
X = (1, 2, 3, 4)
22+
@test_throws ArgumentError matrix(X)
23+
24+
# :table
25+
X = (x=[1,2,3], y=[1,2,3])
26+
@test M.vtrait(X) isa Val{:table}
27+
@test_throws M.InterfaceError matrix(X)
28+
end
29+
30+
@testset "full-interface" begin
31+
M.set_interface_mode(M.FullInterface())
32+
@test M.get_interface_mode() isa M.FullInterface
33+
34+
M.matrix(::Val{:table}, X, ::M.FullInterface; kw...) =
35+
Tables.matrix(X; kw...)
36+
37+
X = (x=[1,2,3], y=[1,2,3])
38+
mX = matrix(X)
39+
@test mX isa Matrix
40+
@test mX == hcat(X.x, X.y)
41+
end

0 commit comments

Comments
 (0)