|
| 1 | +""" |
| 2 | + is_same_except(m1::MLJType, m2::MLJType, exceptions::Symbol...) |
| 3 | +
|
| 4 | +Returns `true` only the following conditions all hold: |
| 5 | +
|
| 6 | +- `m1` and `m2` have the same type. |
| 7 | +
|
| 8 | +- `m1` and `m2` have the same undefined fields. |
| 9 | +
|
| 10 | +- Corresponding fields agree, or are listed as `exceptions`, or have |
| 11 | + `AbstractRNG` as values (one or both) |
| 12 | +
|
| 13 | +Here "agree" is in the sense of "==", unless the objects are themselves of |
| 14 | +`MLJType`, in which case agreement is in the sense of `is_same_except` with |
| 15 | +no exceptions allowed. |
| 16 | +
|
| 17 | +Note that Base.== is overloaded such that `m1 == m2` if and only if |
| 18 | +`is_same_except(m1, m2)`. |
| 19 | +""" |
| 20 | +is_same_except(x1, x2) = ==(x1, x2) |
| 21 | +function is_same_except(m1::M1, m2::M2, |
| 22 | + exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType} |
| 23 | + if typeof(m1) != typeof(m2) |
| 24 | + return false |
| 25 | + end |
| 26 | + defined1 = filter(fieldnames(M1)|>collect) do fld |
| 27 | + isdefined(m1, fld) && !(fld in exceptions) |
| 28 | + end |
| 29 | + defined2 = filter(fieldnames(M1)|>collect) do fld |
| 30 | + isdefined(m2, fld) && !(fld in exceptions) |
| 31 | + end |
| 32 | + if defined1 != defined2 |
| 33 | + return false |
| 34 | + end |
| 35 | + same_values = true |
| 36 | + for fld in defined1 |
| 37 | + same_values = same_values && |
| 38 | + (is_same_except(getfield(m1, fld), getfield(m2, fld)) || |
| 39 | + getfield(m1, fld) isa AbstractRNG || |
| 40 | + getfield(m2, fld) isa AbstractRNG) |
| 41 | + end |
| 42 | + return same_values |
| 43 | +end |
| 44 | + |
| 45 | +==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType} = is_same_except(m1, m2) |
| 46 | + |
| 47 | +# for using `replace` or `replace!` on collections of MLJType objects |
| 48 | +# (eg, Model objects in a learning network) we need a stricter |
| 49 | +# equality and a corresponding definition of `in`. |
| 50 | +Base.isequal(m1::MLJType, m2::MLJType) = (m1 === m2) |
| 51 | + |
| 52 | +# Note: To prevent julia crash, it seems we need to annotate the type |
| 53 | +# of itr: |
| 54 | +function special_in(x, itr)::Union{Bool,Missing} |
| 55 | + for y in itr |
| 56 | + ismissing(y) && return missing |
| 57 | + y === x && return true |
| 58 | + end |
| 59 | + return false |
| 60 | +end |
| 61 | +Base.in(x::MLJType, itr::Set) = special_in(x, itr) |
| 62 | +Base.in(x::MLJType, itr::AbstractVector) = special_in(x, itr) |
| 63 | +Base.in(x::MLJType, itr::Tuple) = special_in(x, itr) |
0 commit comments