Skip to content

Commit 532aaa3

Browse files
committed
"generalize" definition of is_same_except and == for MLJType
1 parent 43116c3 commit 532aaa3

File tree

3 files changed

+157
-36
lines changed

3 files changed

+157
-36
lines changed

src/MLJModelInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ const MODEL_TRAITS = [
2323
:hyperparameter_types,
2424
:hyperparameter_ranges,
2525
:iteration_parameter,
26-
:supports_training_losses]
26+
:supports_training_losses,
27+
:deep_properties]
2728

2829
# ------------------------------------------------------------------------
2930
# Dependencies (ScientificTypes and StatisticalTraits have none)

src/equality.jl

Lines changed: 103 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,116 @@
1+
function _equal_to_depth_one(x1, x2)
2+
names = propertynames(x1)
3+
names === propertynames(x2) || return false
4+
for name in names
5+
getproperty(x1, name) == getproperty(x2, name) || return false
6+
end
7+
return true
8+
end
9+
10+
@doc """
11+
deep_properties(::Type{<:MLJType})
12+
13+
Given an `MLJType` subtype `M`, the value of this trait should be a
14+
tuple of any properties of `M` to be regarded as "deep".
15+
16+
When two instances of type `M` are to be tested for equality, in the
17+
sense of `==` or `is_same_except`, then the values of a "deep"
18+
property (whose values are assumed to be of composite type) are deemed
19+
to agree if all corresponding properties *of those property values*
20+
are `==`.
21+
22+
Any property of `M` whose values are themselves of `MLJType` are
23+
"deep" automatically, and should not be included in the trait return
24+
value.
25+
26+
See also [`is_same_except`](@ref)
27+
28+
### Example
29+
30+
Consider an `MLJType` subtype `Foo`, with a single field of
31+
type `Bar` which is *not* a subtype of `MLJType`:
32+
33+
mutable struct Bar
34+
x::Int
35+
end
36+
37+
mutable struct Foo <: MLJType
38+
bar::Bar
39+
end
40+
41+
Then the mutability of `Foo` implies `Foo(1) != Foo(1)` and so, by the
42+
definition `==` for `MLJType` objects (see [`is_same_except`](@ref))
43+
we have
44+
45+
Bar(Foo(1)) != Bar(Foo(1))
46+
47+
However after the declaration
48+
49+
MLJModelInterface.deep_properties(::Type{<:Foo}) = (:bar,)
50+
51+
We have
52+
53+
Bar(Foo(1)) == Bar(Foo(1))
54+
155
"""
2-
is_same_except(m1::MLJType, m2::MLJType, exceptions::Symbol...)
56+
StatisticalTraits.deep_properties
57+
58+
59+
"""
60+
is_same_except(m1, m2, exceptions::Symbol...; deep_properties=Symbol[])
61+
62+
If both `m1` and `m2` are of `MLJType`, return `true` if the
63+
following conditions all hold, and `false` otherwise:
364
4-
Returns `true` only the following conditions all hold:
65+
- `typeof(m1) === typeof(m2)`
566
6-
- `m1` and `m2` have the same type.
67+
- `propertynames(m1) === propertynames(m2)`
768
8-
- `m1` and `m2` have the same undefined fields.
69+
- with the exception of properties listed as `exceptions` or bound to
70+
an `AbstractRNG`, each pair of corresponding property values is
71+
either "equal" or both undefined.
972
10-
- Corresponding fields agree, or are listed as `exceptions`, or have
11-
`AbstractRNG` as values (one or both)
73+
The meaining of "equal" depends on the type of the property value:
1274
13-
Here "agree" is in the sense of "==", unless the objects are themselves of
14-
`MLJType`, in which case agreement is in the sense of `is_same_except` with
15-
no exceptions allowed.
75+
- values that are themselves of `MLJType` are "equal" if they are
76+
equal in the sense of `is_same_except` with no exceptions.
77+
78+
- values that are not of `MLJType` are "equal" if they are `==`.
79+
80+
In the special case of a "deep" property, "equal" has a different
81+
meaning; see [`deep_properties`](@ref)) for details.
82+
83+
If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`.
1684
17-
Note that Base.== is overloaded such that `m1 == m2` if and only if
18-
`is_same_except(m1, m2)`.
1985
"""
2086
is_same_except(x1, x2) = ==(x1, x2)
21-
function is_same_except(m1::M1, m2::M2,
22-
exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType}
23-
if typeof(m1) != typeof(m2)
24-
return false
25-
end
26-
defined1 = filter(fieldnames(M1)|>collect) do fld
27-
isdefined(m1, fld) && !(fld in exceptions)
28-
end
29-
defined2 = filter(fieldnames(M1)|>collect) do fld
30-
isdefined(m2, fld) && !(fld in exceptions)
31-
end
32-
if defined1 != defined2
33-
return false
34-
end
35-
same_values = true
36-
for fld in defined1
37-
same_values = same_values &&
38-
(is_same_except(getfield(m1, fld), getfield(m2, fld)) ||
39-
getfield(m1, fld) isa AbstractRNG ||
40-
getfield(m2, fld) isa AbstractRNG)
87+
function is_same_except(m1::M1,
88+
m2::M2,
89+
exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType}
90+
typeof(m1) === typeof(m2) || return false
91+
names = propertynames(m1)
92+
propertynames(m2) === names || return false
93+
94+
for name in names
95+
if !(name in exceptions)
96+
if !isdefined(m1, name)
97+
!isdefined(m2, name) || return false
98+
elseif isdefined(m2, name)
99+
if name in deep_properties(M1)
100+
_equal_to_depth_one(getproperty(m1,name),
101+
getproperty(m2, name)) || return false
102+
else
103+
(is_same_except(getproperty(m1, name),
104+
getproperty(m2, name)) ||
105+
getproperty(m1, name) isa AbstractRNG ||
106+
getproperty(m2, name) isa AbstractRNG) || return false
107+
end
108+
else
109+
return false
110+
end
111+
end
41112
end
42-
return same_values
113+
return true
43114
end
44115

45116
==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType} = is_same_except(m1, m2)

test/equality.jl

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,24 @@ mutable struct Foo <: MLJType
88
y::Int
99
end
1010

11-
mutable struct Bar <: MLJType
11+
mutable struct Bar{names} <: MLJType
1212
rng::AbstractRNG
13-
x::Int
14-
y::Int
13+
v::Tuple{Int,Int}
14+
Bar{names}(rng, x, y) where names =
15+
new{names}(rng, (x, y))
16+
end
17+
18+
Bar(rng, x, y) = Bar{(:x, :y)}(rng, x, y)
19+
20+
# overload `getproperty` so that components of `v` are accessed with
21+
# the names given in `names` (which will be (:x, :y) when using
22+
# the above constructor):
23+
Base.propertynames(::Bar{names}) where names = (:rng, names...)
24+
function Base.getproperty(b::Bar{names}, name::Symbol) where names
25+
name === :rng && return getfield(b, :rng)
26+
v = getfield(b, :v)
27+
name === names[1] && return v[1]
28+
return v[2]
1529
end
1630

1731
mutable struct Super <: MLJType
@@ -25,6 +39,35 @@ mutable struct Partial <: MLJType
2539
Partial(x) = new(x)
2640
end
2741

42+
mutable struct Sub
43+
x::Int
44+
end
45+
46+
mutable struct Deep
47+
x::Int
48+
s::Union{Sub,Int}
49+
end
50+
51+
mutable struct Super2 <: MLJType
52+
sub::Sub
53+
z::Int
54+
end
55+
56+
MLJModelInterface.deep_properties(::Type{<:Super2}) = (:sub,)
57+
58+
@testset "_equal_to_depth_one" begin
59+
d1 = Deep(1, 2)
60+
d2 = Deep(1, 2)
61+
@test MLJModelInterface._equal_to_depth_one(d1, d2)
62+
d2.x = 3
63+
@test !MLJModelInterface._equal_to_depth_one(d1, d2)
64+
65+
d1 = Deep(1, Sub(2))
66+
d2 = Deep(1, Sub(2))
67+
@test !MLJModelInterface._equal_to_depth_one(d1, d2)
68+
end
69+
70+
2871
@testset "equality for MLJType" begin
2972
f1 = Foo(MersenneTwister(7), 1, 2)
3073
f2 = Foo(MersenneTwister(8), 1, 2)
@@ -56,9 +99,15 @@ end
5699

57100
p1 = Partial(1)
58101
p2 = Partial(1)
102+
@test p1 == p2
59103
p2.y = [1,2]
60104
@test !(p1 == p2)
61105

106+
# test of "deep" properties
107+
s1 = Super2(Sub(1), 2)
108+
s2 = Super2(Sub(1), 2)
109+
@test s1 == s2
110+
62111
end
63112

64113
@testset "in(x, collection) for MLJType" begin

0 commit comments

Comments
 (0)