Skip to content

Commit e1ed5b4

Browse files
authored
Merge pull request #27 from alan-turing-institute/dev
For 0.2.0 release
2 parents 8a4a3e8 + f1a4c67 commit e1ed5b4

File tree

7 files changed

+171
-8
lines changed

7 files changed

+171
-8
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.1.8"
4+
version = "0.2.0"
55

66
[deps]
7+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
89

910
[compat]

src/MLJModelInterface.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module MLJModelInterface
22

33
# ------------------------------------------------------------------------
4-
# Dependency (ScientificTypes itself does not have dependencies)
4+
# Dependencies (ScientificTypes itself does not have dependencies)
55
using ScientificTypes
6+
using Random
67

78
# ------------------------------------------------------------------------
89
# exports
@@ -36,12 +37,21 @@ export input_scitype, output_scitype, target_scitype,
3637
export matrix, int, classes, decoder, table,
3738
nrows, selectrows, selectcols, select
3839

40+
# equality
41+
export is_same_except
42+
3943
# re-exports from ScientificTypes
4044
export Scientific, Found, Unknown, Known, Finite, Infinite,
4145
OrderedFactor, Multiclass, Count, Continuous, Textual,
4246
Binary, ColorImage, GrayImage, Image, Table
4347
export scitype, scitype_union, elscitype, nonmissing, trait
4448

49+
# ------------------------------------------------------------------------
50+
# To be extended
51+
52+
import Base.==
53+
import Base: in, isequal
54+
#
4555
# ------------------------------------------------------------------------
4656
# Mode trick
4757

@@ -79,12 +89,13 @@ abstract type Static <: Unsupervised end
7989
# includes
8090

8191
include("utils.jl")
82-
8392
include("data_utils.jl")
8493
include("metadata_utils.jl")
8594

8695
include("model_traits.jl")
8796
include("model_def.jl")
8897
include("model_api.jl")
98+
include("equality.jl")
99+
89100

90101
end # module

src/equality.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
is_same_except(m1::MLJType, m2::MLJType, exceptions::Symbol...)
3+
4+
Returns `true` only the following conditions all hold:
5+
6+
- `m1` and `m2` have the same type.
7+
8+
- `m1` and `m2` have the same undefined fields.
9+
10+
- Corresponding fields agree, or are listed as `exceptions`, or have
11+
`AbstractRNG` as values (one or both)
12+
13+
Here "agree" is in the sense of "==", unless the objects are themselves of
14+
`MLJType`, in which case agreement is in the sense of `is_same_except` with
15+
no exceptions allowed.
16+
17+
Note that Base.== is overloaded such that `m1 == m2` if and only if
18+
`is_same_except(m1, m2)`.
19+
"""
20+
is_same_except(x1, x2) = ==(x1, x2)
21+
function is_same_except(m1::M1, m2::M2,
22+
exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType}
23+
if typeof(m1) != typeof(m2)
24+
return false
25+
end
26+
defined1 = filter(fieldnames(M1)|>collect) do fld
27+
isdefined(m1, fld) && !(fld in exceptions)
28+
end
29+
defined2 = filter(fieldnames(M1)|>collect) do fld
30+
isdefined(m2, fld) && !(fld in exceptions)
31+
end
32+
if defined1 != defined2
33+
return false
34+
end
35+
same_values = true
36+
for fld in defined1
37+
same_values = same_values &&
38+
(is_same_except(getfield(m1, fld), getfield(m2, fld)) ||
39+
getfield(m1, fld) isa AbstractRNG ||
40+
getfield(m2, fld) isa AbstractRNG)
41+
end
42+
return same_values
43+
end
44+
45+
==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType} = is_same_except(m1, m2)
46+
47+
# for using `replace` or `replace!` on collections of MLJType objects
48+
# (eg, Model objects in a learning network) we need a stricter
49+
# equality and a corresponding definition of `in`.
50+
Base.isequal(m1::MLJType, m2::MLJType) = (m1 === m2)
51+
52+
# Note: To prevent julia crash, it seems we need to annotate the type
53+
# of itr:
54+
function special_in(x, itr)::Union{Bool,Missing}
55+
for y in itr
56+
ismissing(y) && return missing
57+
y === x && return true
58+
end
59+
return false
60+
end
61+
Base.in(x::MLJType, itr::Set) = special_in(x, itr)
62+
Base.in(x::MLJType, itr::AbstractVector) = special_in(x, itr)
63+
Base.in(x::MLJType, itr::Tuple) = special_in(x, itr)

src/model_api.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,10 @@ function predict_median end
3131
function transform end
3232
function inverse_transform end
3333

34+
# models can optionally overload these for enable serialization in a
35+
# custom format:
36+
function save end
37+
function restore end
38+
3439
# operations implemented by some meta-models:
3540
function evaluate end

src/model_traits.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@ const MODEL_TRAITS = [
1010
:prediction_type, :implemented_methods, :hyperparameters,
1111
:hyperparameter_types, :hyperparameter_ranges]
1212

13-
const SUPERVISED_TRAITS = setdiff(MODEL_TRAITS, [:output_scitype])
14-
15-
const UNSUPERVISED_TRAITS = setdiff(MODEL_TRAITS,
16-
[:target_scitype, :prediction_type, :supports_weights])
17-
1813
for trait in MODEL_TRAITS
1914
ex = quote
2015
$trait(x) = $trait(typeof(x))

test/equality.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using Random
2+
using MLJModelInterface
3+
using Test
4+
5+
mutable struct Foo <: MLJType
6+
rng::AbstractRNG
7+
x::Int
8+
y::Int
9+
end
10+
11+
mutable struct Bar <: MLJType
12+
rng::AbstractRNG
13+
x::Int
14+
y::Int
15+
end
16+
17+
mutable struct Super <: MLJType
18+
sub::Foo
19+
z::Int
20+
end
21+
22+
mutable struct Partial <: MLJType
23+
x::Int
24+
y::Vector{Int}
25+
Partial(x) = new(x)
26+
end
27+
28+
@testset "equality for MLJType" begin
29+
f1 = Foo(MersenneTwister(7), 1, 2)
30+
f2 = Foo(MersenneTwister(8), 1, 2)
31+
32+
@test f1.rng != f2.rng
33+
@test f1 == f2
34+
f1.x = 10
35+
@test f1 != f2
36+
b = Bar(MersenneTwister(7), 1, 2)
37+
@test f2 != b
38+
39+
@test is_same_except(f1, f2, :x)
40+
f1.y = 20
41+
@test f1 != f2
42+
@test is_same_except(f1, f2, :x, :y)
43+
44+
f1 = Foo(MersenneTwister(7), 1, 2)
45+
f2 = Foo(MersenneTwister(8), 1, 2)
46+
s1 = Super(f1, 20)
47+
s2 = Super(f2, 20)
48+
@test s1 == s2
49+
s2.sub.x = 10
50+
@test f1 != f2
51+
52+
@test !(f1 == Super(f1, 4))
53+
54+
@test !(isequal(Foo(MersenneTwister(1), 1, 2),
55+
Foo(MersenneTwister(1), 1, 2)))
56+
57+
p1 = Partial(1)
58+
p2 = Partial(1)
59+
p2.y = [1,2]
60+
@test !(p1 == p2)
61+
62+
end
63+
64+
@testset "in(x, collection) for MLJType" begin
65+
f1 = Foo(MersenneTwister(7), 1, 2)
66+
f2 = Foo(MersenneTwister(7), 1, 2)
67+
f3 = Super(f1, 20)
68+
69+
tv = (f1, f3)
70+
tk = (f2, f3)
71+
tw = (f3, f3)
72+
v = [tv...]
73+
k = [tk...]
74+
w = [tw...]
75+
@test f1 in tv
76+
@test !(f1 in tk)
77+
@test !(f1 in tw)
78+
@test f1 in v
79+
@test !(f1 in k)
80+
@test !(f1 in w)
81+
@test f1 in Set(v)
82+
@test !(f1 in Set(k))
83+
@test !(f1 in Set(w))
84+
85+
end
86+
87+
true

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ include("metadata_utils.jl")
1818
include("model_def.jl")
1919
include("model_api.jl")
2020
include("model_traits.jl")
21+
include("equality.jl")

0 commit comments

Comments
 (0)