diff --git a/src/schema.jl b/src/schema.jl index 7ef8e88f..22dcab70 100644 --- a/src/schema.jl +++ b/src/schema.jl @@ -53,6 +53,28 @@ Base.merge!(a::Schema, b::Schema) = (merge!(a.schema, b.schema); a) Base.keys(schema::Schema) = keys(schema.schema) Base.haskey(schema::Schema, key) = haskey(schema.schema, key) +function ==(first::Schema, second::Schema) + first === second && return true + first.schema === second.schema && return true + length(first.schema) != length(second.schema) && return false + for key in keys(first) + !haskey(second, key) && return false + second[key] != first[key] && return false + end + true +end + +function Base.isequal(first::Schema, second::Schema) + first === second && return true + first.schema === second.schema && return true + length(first.schema) != length(second.schema) && return false + for key in keys(first) + !haskey(second, key) && return false + !isequal(second[key], first[key]) && return false + end + true +end + """ schema([terms::AbstractVector{<:AbstractTerm}, ]data, hints::Dict{Symbol}) schema(term::AbstractTerm, data, hints::Dict{Symbol}) diff --git a/src/terms.jl b/src/terms.jl index 2ee5f785..2fe788a6 100644 --- a/src/terms.jl +++ b/src/terms.jl @@ -1,3 +1,4 @@ +import Base.== , Base.isequal abstract type AbstractTerm end const TermOrTerms = Union{AbstractTerm, NTuple{N, AbstractTerm} where N} const TupleTerm = NTuple{N, TermOrTerms} where N @@ -38,6 +39,8 @@ struct ConstantTerm{T<:Number} <: AbstractTerm end width(::ConstantTerm) = 1 +==(first::ConstantTerm, second::ConstantTerm) = first.n == second.n +isequal(first::ConstantTerm, second::ConstantTerm) = isequal(first.n, second.n) """ FormulaTerm{L,R} <: AbstractTerm @@ -54,6 +57,13 @@ struct FormulaTerm{L,R} <: AbstractTerm rhs::R end +==(first::FormulaTerm, second::FormulaTerm) = + first.lhs == second.lhs && + first.rhs == second.rhs +isequal(first::FormulaTerm, second::FormulaTerm) = + isequal(first.lhs, second.lhs) && + isequal(first.rhs, second.rhs) + """ FunctionTerm{Forig,Fanon,Names} <: AbstractTerm @@ -127,6 +137,13 @@ FunctionTerm(forig::Fo, fanon::Fa, names::NTuple{N,Symbol}, FunctionTerm{Fo, Fa, names}(forig, fanon, exorig, args_parsed) width(::FunctionTerm) = 1 +==(first::FunctionTerm, second::FunctionTerm) = + first.forig == second.forig && + first.args_parsed == second.args_parsed +isequal(first::FunctionTerm, second::FunctionTerm) = + isequal(first.forig, second.forig) && + isequal(first.args_parsed, second.args_parsed) + """ InteractionTerm{Ts} <: AbstractTerm @@ -174,6 +191,10 @@ struct InteractionTerm{Ts} <: AbstractTerm end width(ts::InteractionTerm) = prod(width(t) for t in ts.terms) +==(first::InteractionTerm, second::InteractionTerm) = + first.terms == second.terms +isequal(first::InteractionTerm, second::InteractionTerm) = + isequal(first.terms, second.terms) """ InterceptTerm{HasIntercept} <: AbstractTerm @@ -187,6 +208,11 @@ via the [`implicit_intercept`](@ref) trait). struct InterceptTerm{HasIntercept} <: AbstractTerm end width(::InterceptTerm{H}) where {H} = H ? 1 : 0 +==(first::InterceptTerm, second::InterceptTerm) = + width(first) == width(second) +isequal(first::InterceptTerm, second::InterceptTerm) = + isequal(width(first), width(second)) + # Typed terms """ @@ -211,6 +237,19 @@ struct ContinuousTerm{T} <: AbstractTerm end width(::ContinuousTerm) = 1 +==(first::ContinuousTerm, second::ContinuousTerm) = + first.sym == second.sym && + first.mean == second.mean && + first.var == second.var && + first.min == second.min && + first.max == second.max + +isequal(first::ContinuousTerm, second::ContinuousTerm) = + isequal(first.sym, second.sym) && + isequal(first.mean, second.mean) && + isequal(first.var, second.var) && + isequal(first.min, second.min) && + isequal(first.max, second.max) """ CategoricalTerm{C,T,N} <: AbstractTerm @@ -233,6 +272,14 @@ width(::CategoricalTerm{C,T,N}) where {C,T,N} = N CategoricalTerm(sym::Symbol, contrasts::ContrastsMatrix{C,T}) where {C,T} = CategoricalTerm{C,T,length(contrasts.termnames)}(sym, contrasts) +==(first::CategoricalTerm, second::CategoricalTerm) = + first.sym == second.sym && + width(first) == width(second) && + first.contrasts == second.contrasts +isequal(first::CategoricalTerm, second::CategoricalTerm) = + isequal(first.sym, second.sym) && + isequal(width(first), width(second)) && + isequal(first.contrasts, second.contrasts) """ MatrixTerm{Ts} <: AbstractTerm @@ -250,6 +297,11 @@ end MatrixTerm(t::AbstractTerm) = MatrixTerm((t, )) width(t::MatrixTerm) = sum(width(tt) for tt in t.terms) +==(first::MatrixTerm, second::MatrixTerm) = + first.terms == second.terms +isequal(first::MatrixTerm, second::MatrixTerm) = + isequal(first.terms, second.terms) + """ collect_matrix_terms(ts::TupleTerm) collect_matrix_terms(t::AbstractTerm) = collect_matrix_term((t, )) diff --git a/test/schema.jl b/test/schema.jl index 7786f44b..b4521648 100644 --- a/test/schema.jl +++ b/test/schema.jl @@ -1,10 +1,47 @@ @testset "schemas" begin - using StatsModels: schema, apply_schema, FullRank - f = @formula(y ~ 1 + a + b + c + b&c) - df = (y = rand(9), a = 1:9, b = rand(9), c = repeat(["d","e","f"], 3)) + f = @formula(y ~ 1 + a + log(b) + c + b & c) + y = rand(9) + b = rand(9) + + df = (y = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3)) f = apply_schema(f, schema(f, df)) @test f == apply_schema(f, schema(f, df)) + df2 = (y = y, a = 1:9, b = b, c = [df.c; df.c]) + df3 = (y = y, a = 1:9, b = b, c = repeat(["a", "b", "c"], 3)) + df4 = (y = [df.y; df.y], a = [1:9; 1:9], b = [b; b], c = [df.c; df.c]) + df5 = (z = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3)) + df6 = (y = y, a = 2:10, b = b, c = repeat(["a", "b", "c"], 3)) + df7 = (w = y, d = 1:9, x = b, z = repeat(["d", "e", "f"], 3)) + df8 = (y = y, a = 1:9, c = repeat(["d", "e", "f"], 3)) + + sch = schema(df, Dict(:c => DummyCoding(base="e"))) + sch2 = schema(df, Dict(:c => EffectsCoding(base="e"))) + + @test schema(df) == schema(df2) + @test apply_schema(f, schema(df)) == apply_schema(f, schema(df2)) + @test schema(df) != schema(df3) + @test schema(df) != schema(df4) + @test schema(df) != schema(df5) + @test schema(df) != schema(df6) + @test schema(df) != schema(df7) + @test schema(df) != schema(df8) + @test schema(df8) != schema(df) + @test apply_schema(f, schema(df)) == apply_schema(f, schema(df5)) + @test sch != sch2 + + @test isequal(schema(df), schema(df2)) + @test isequal(apply_schema(f, schema(df)), apply_schema(f, schema(df2))) + @test !isequal(schema(df), schema(df3)) + @test !isequal(schema(df), schema(df4)) + @test !isequal(schema(df), schema(df5)) + @test !isequal(schema(df), schema(df6)) + @test !isequal(schema(df), schema(df7)) + @test !isequal(schema(df), schema(df8)) + @test !isequal(schema(df8), schema(df)) + @test isequal(apply_schema(f, schema(df)), apply_schema(f, schema(df5))) + @test !isequal(sch, sch2) + end