|
| 1 | +function _equal_to_depth_one(x1, x2) |
| 2 | + names = propertynames(x1) |
| 3 | + names === propertynames(x2) || return false |
| 4 | + for name in names |
| 5 | + getproperty(x1, name) == getproperty(x2, name) || return false |
| 6 | + end |
| 7 | + return true |
| 8 | +end |
| 9 | + |
| 10 | +@doc """ |
| 11 | + deep_properties(::Type{<:MLJType}) |
| 12 | +
|
| 13 | +Given an `MLJType` subtype `M`, the value of this trait should be a |
| 14 | +tuple of any properties of `M` to be regarded as "deep". |
| 15 | +
|
| 16 | +When two instances of type `M` are to be tested for equality, in the |
| 17 | +sense of `==` or `is_same_except`, then the values of a "deep" |
| 18 | +property (whose values are assumed to be of composite type) are deemed |
| 19 | +to agree if all corresponding properties *of those property values* |
| 20 | +are `==`. |
| 21 | +
|
| 22 | +Any property of `M` whose values are themselves of `MLJType` are |
| 23 | +"deep" automatically, and should not be included in the trait return |
| 24 | +value. |
| 25 | +
|
| 26 | +See also [`is_same_except`](@ref) |
| 27 | +
|
| 28 | +### Example |
| 29 | +
|
| 30 | +Consider an `MLJType` subtype `Foo`, with a single field of |
| 31 | +type `Bar` which is *not* a subtype of `MLJType`: |
| 32 | +
|
| 33 | + mutable struct Bar |
| 34 | + x::Int |
| 35 | + end |
| 36 | +
|
| 37 | + mutable struct Foo <: MLJType |
| 38 | + bar::Bar |
| 39 | + end |
| 40 | +
|
| 41 | +Then the mutability of `Foo` implies `Foo(1) != Foo(1)` and so, by the |
| 42 | +definition `==` for `MLJType` objects (see [`is_same_except`](@ref)) |
| 43 | +we have |
| 44 | +
|
| 45 | + Bar(Foo(1)) != Bar(Foo(1)) |
| 46 | +
|
| 47 | +However after the declaration |
| 48 | +
|
| 49 | + MLJModelInterface.deep_properties(::Type{<:Foo}) = (:bar,) |
| 50 | +
|
| 51 | +We have |
| 52 | +
|
| 53 | + Bar(Foo(1)) == Bar(Foo(1)) |
| 54 | +
|
1 | 55 | """
|
2 |
| - is_same_except(m1::MLJType, m2::MLJType, exceptions::Symbol...) |
| 56 | +StatisticalTraits.deep_properties |
| 57 | + |
| 58 | + |
| 59 | +""" |
| 60 | + is_same_except(m1, m2, exceptions::Symbol...; deep_properties=Symbol[]) |
| 61 | +
|
| 62 | +If both `m1` and `m2` are of `MLJType`, return `true` if the |
| 63 | +following conditions all hold, and `false` otherwise: |
3 | 64 |
|
4 |
| -Returns `true` only the following conditions all hold: |
| 65 | +- `typeof(m1) === typeof(m2)` |
5 | 66 |
|
6 |
| -- `m1` and `m2` have the same type. |
| 67 | +- `propertynames(m1) === propertynames(m2)` |
7 | 68 |
|
8 |
| -- `m1` and `m2` have the same undefined fields. |
| 69 | +- with the exception of properties listed as `exceptions` or bound to |
| 70 | + an `AbstractRNG`, each pair of corresponding property values is |
| 71 | + either "equal" or both undefined. |
9 | 72 |
|
10 |
| -- Corresponding fields agree, or are listed as `exceptions`, or have |
11 |
| - `AbstractRNG` as values (one or both) |
| 73 | +The meaining of "equal" depends on the type of the property value: |
12 | 74 |
|
13 |
| -Here "agree" is in the sense of "==", unless the objects are themselves of |
14 |
| -`MLJType`, in which case agreement is in the sense of `is_same_except` with |
15 |
| -no exceptions allowed. |
| 75 | +- values that are themselves of `MLJType` are "equal" if they are |
| 76 | +equal in the sense of `is_same_except` with no exceptions. |
| 77 | +
|
| 78 | +- values that are not of `MLJType` are "equal" if they are `==`. |
| 79 | +
|
| 80 | +In the special case of a "deep" property, "equal" has a different |
| 81 | +meaning; see [`deep_properties`](@ref)) for details. |
| 82 | +
|
| 83 | +If `m1` or `m2` are not `MLJType` objects, then return `==(m1, m2)`. |
16 | 84 |
|
17 |
| -Note that Base.== is overloaded such that `m1 == m2` if and only if |
18 |
| -`is_same_except(m1, m2)`. |
19 | 85 | """
|
20 | 86 | is_same_except(x1, x2) = ==(x1, x2)
|
21 |
| -function is_same_except(m1::M1, m2::M2, |
22 |
| - exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType} |
23 |
| - if typeof(m1) != typeof(m2) |
24 |
| - return false |
25 |
| - end |
26 |
| - defined1 = filter(fieldnames(M1)|>collect) do fld |
27 |
| - isdefined(m1, fld) && !(fld in exceptions) |
28 |
| - end |
29 |
| - defined2 = filter(fieldnames(M1)|>collect) do fld |
30 |
| - isdefined(m2, fld) && !(fld in exceptions) |
31 |
| - end |
32 |
| - if defined1 != defined2 |
33 |
| - return false |
34 |
| - end |
35 |
| - same_values = true |
36 |
| - for fld in defined1 |
37 |
| - same_values = same_values && |
38 |
| - (is_same_except(getfield(m1, fld), getfield(m2, fld)) || |
39 |
| - getfield(m1, fld) isa AbstractRNG || |
40 |
| - getfield(m2, fld) isa AbstractRNG) |
| 87 | +function is_same_except(m1::M1, |
| 88 | + m2::M2, |
| 89 | + exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType} |
| 90 | + typeof(m1) === typeof(m2) || return false |
| 91 | + names = propertynames(m1) |
| 92 | + propertynames(m2) === names || return false |
| 93 | + |
| 94 | + for name in names |
| 95 | + if !(name in exceptions) |
| 96 | + if !isdefined(m1, name) |
| 97 | + !isdefined(m2, name) || return false |
| 98 | + elseif isdefined(m2, name) |
| 99 | + if name in deep_properties(M1) |
| 100 | + _equal_to_depth_one(getproperty(m1,name), |
| 101 | + getproperty(m2, name)) || return false |
| 102 | + else |
| 103 | + (is_same_except(getproperty(m1, name), |
| 104 | + getproperty(m2, name)) || |
| 105 | + getproperty(m1, name) isa AbstractRNG || |
| 106 | + getproperty(m2, name) isa AbstractRNG) || return false |
| 107 | + end |
| 108 | + else |
| 109 | + return false |
| 110 | + end |
| 111 | + end |
41 | 112 | end
|
42 |
| - return same_values |
| 113 | + return true |
43 | 114 | end
|
44 | 115 |
|
45 | 116 | ==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType} = is_same_except(m1, m2)
|
|
0 commit comments