Skip to content

Commit e7afc34

Browse files
authored
Merge pull request #980 from JuliaAI/constructor
Have wrappers overload `constructor` trait
2 parents 14441aa + 52d16df commit e7afc34

File tree

9 files changed

+38
-23
lines changed

9 files changed

+38
-23
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "1.3"
4+
version = "1.4.0"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -47,7 +47,7 @@ DelimitedFiles = "1"
4747
Distributions = "0.25.3"
4848
InvertedIndices = "1"
4949
LearnAPI = "0.1"
50-
MLJModelInterface = "1.7"
50+
MLJModelInterface = "1.10"
5151
Missings = "0.4, 1"
5252
OrderedCollections = "1.1"
5353
Parameters = "0.12"
@@ -58,7 +58,7 @@ Reexport = "1.2"
5858
ScientificTypes = "3"
5959
StatisticalMeasures = "0.1.1"
6060
StatisticalMeasuresBase = "0.1.1"
61-
StatisticalTraits = "3.2"
61+
StatisticalTraits = "3.3"
6262
Statistics = "1"
6363
StatsBase = "0.32, 0.33, 0.34"
6464
Tables = "0.2, 1.0"

src/composition/models/pipelines.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,9 @@ end
599599

600600
MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(p))
601601

602+
MMI.package_name(::Type{<:SomePipeline}) = "MLJBase"
603+
MMI.load_path(::Type{<:SomePipeline}) = "MLJBase.Pipeline"
604+
MMI.constructor(::Type{<:SomePipeline}) = Pipeline
602605

603606
# ## Training losses
604607

src/composition/models/stacking.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,16 +264,17 @@ function Base.setproperty!(stack::Stack{modelnames}, _name::Symbol, val) where m
264264
end
265265

266266

267+
# # TRAITS
268+
267269
MMI.target_scitype(::Type{<:Stack{modelnames, input_scitype, target_scitype}}) where
268270
{modelnames, input_scitype, target_scitype} = target_scitype
269271

270272

271273
MMI.input_scitype(::Type{<:Stack{modelnames, input_scitype, target_scitype}}) where
272274
{modelnames, input_scitype, target_scitype} = input_scitype
273275

274-
275-
MLJBase.load_path(::Type{<:ProbabilisticStack}) = "MLJBase.ProbabilisticStack"
276-
MLJBase.load_path(::Type{<:DeterministicStack}) = "MLJBase.DeterministicStack"
276+
MMI.constructor(::Type{<:Stack}) = Stack
277+
MLJBase.load_path(::Type{<:Stack}) = "MLJBase.Stack"
277278
MLJBase.package_name(::Type{<:Stack}) = "MLJBase"
278279
MLJBase.package_uuid(::Type{<:Stack}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
279280
MLJBase.package_url(::Type{<:Stack}) = "https://github.com/JuliaAI/MLJBase.jl"

src/composition/models/transformed_target_model.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ const TT_SUPPORTED_ATOMS = (
1010
:Deterministic,
1111
:DeterministicUnsupervisedDetector,
1212
:DeterministicSupervisedDetector,
13-
:Interval)
13+
:Interval,
14+
)
1415

1516
# Each supported atomic type gets its own wrapper:
1617

@@ -265,6 +266,10 @@ MMI.package_uuid(::Type{<:SomeTT}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
265266
MMI.is_wrapper(::Type{<:SomeTT}) = true
266267
MMI.package_url(::Type{<:SomeTT}) = "https://github.com/JuliaAI/MLJBase.jl"
267268

269+
MMI.load_path(::Type{<:SomeTT}) = "MLJBase.TransformedTargetModel"
270+
MMI.constructor(::Type{<:SomeTT}) = TransformedTargetModel
271+
272+
268273
for New in TT_TYPE_EXS
269274
quote
270275
MMI.iteration_parameter(::Type{<:$New{M}}) where M =

src/resampling.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,9 +1548,11 @@ end
15481548
compact=false,
15491549
)
15501550
1551+
*Private method.* Use at own risk.
1552+
15511553
Resampling model wrapper, used internally by the `fit` method of `TunedModel` instances
1552-
and `IteratedModel` instances. See [`evaluate!](@ref) for options. Not intended for use by
1553-
general user, who will ordinarily use [`evaluate!`](@ref) directly.
1554+
and `IteratedModel` instances. See [`evaluate!`](@ref) for meaning of the options. Not
1555+
intended for use by general user, who will ordinarily use [`evaluate!`](@ref) directly.
15541556
15551557
Given a machine `mach = machine(resampler, args...)` one obtains a performance evaluation
15561558
of the specified `model`, performed according to the prescribed `resampling` strategy and
@@ -1592,16 +1594,6 @@ mutable struct Resampler{S, L} <: Model
15921594
compact::Bool
15931595
end
15941596

1595-
# Some traits are markded as `missing` because we cannot determine
1596-
# them from from the type because we have removed `M` (for "model"} as
1597-
# a `Resampler` type parameter. See
1598-
# https://github.com/JuliaAI/MLJTuning.jl/issues/141#issue-951221466
1599-
1600-
StatisticalTraits.is_wrapper(::Type{<:Resampler}) = true
1601-
StatisticalTraits.supports_weights(::Type{<:Resampler}) = missing
1602-
StatisticalTraits.supports_class_weights(::Type{<:Resampler}) = missing
1603-
StatisticalTraits.is_pure_julia(::Type{<:Resampler}) = true
1604-
16051597
function MLJModelInterface.clean!(resampler::Resampler)
16061598
warning = ""
16071599
if resampler.measure === nothing && resampler.model !== nothing
@@ -1787,11 +1779,16 @@ function MLJModelInterface.update(
17871779

17881780
end
17891781

1790-
# The input and target scitypes cannot be determined from the type
1791-
# because we have removed `M` (for "model") as a `Resampler` type
1792-
# parameter. See
1782+
# Some traits are marked as `missing` because we cannot determine
1783+
# them from from the type because we have removed `M` (for "model"} as
1784+
# a `Resampler` type parameter. See
17931785
# https://github.com/JuliaAI/MLJTuning.jl/issues/141#issue-951221466
17941786

1787+
StatisticalTraits.is_wrapper(::Type{<:Resampler}) = true
1788+
StatisticalTraits.supports_weights(::Type{<:Resampler}) = missing
1789+
StatisticalTraits.supports_class_weights(::Type{<:Resampler}) = missing
1790+
StatisticalTraits.is_pure_julia(::Type{<:Resampler}) = true
1791+
StatisticalTraits.constructor(::Type{<:Resampler}) = Resampler
17951792
StatisticalTraits.input_scitype(::Type{<:Resampler}) = Unknown
17961793
StatisticalTraits.target_scitype(::Type{<:Resampler}) = Unknown
17971794
StatisticalTraits.package_name(::Type{<:Resampler}) = "MLJBase"

test/composition/models/pipelines.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ end
9595

9696
@testset "public constructor" begin
9797
# un-named components:
98-
@test Pipeline(m, t, u) isa UnsupervisedPipeline
98+
flute = Pipeline(m, t, u)
99+
@test flute isa UnsupervisedPipeline
100+
@test MLJBase.constructor(flute) == Pipeline
99101
@test Pipeline(m, t, u, p) isa ProbabilisticPipeline
100102
@test Pipeline(m, t, u, p, operation=predict_mean) isa DeterministicPipeline
101103
@test Pipeline(u, p, u, operation=predict_mean) isa DeterministicPipeline

test/composition/models/stacking.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ end
202202
measures=rmse,
203203
resampling=CV(;nfolds=3),
204204
models...)
205+
206+
@test MLJBase.constructor(mystack) == Stack
207+
205208
@test mystack.ridge_lambda.lambda == 0.1
206209
@test mystack.metalearner isa FooBarRegressor
207210
@test mystack.resampling isa CV

test/composition/models/transformed_target_model.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ avg_nonlinear = g(mean(f(y))) # = g(mean(z))
8686

8787
# Test wrapping using f and g:
8888
model = TransformedTargetModel(atom, transformer=f, inverse=g)
89+
@test MLJBase.constructor(model) == TransformedTargetModel
8990
fr1, _, _ = MMI.fit(model, 0, X, y)
9091
@test first(predict(model, fr1, X)) fill(avg_nonlinear, 5)
9192

test/resampling.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,9 @@ end
606606
holdout = Holdout(fraction_train=0.75)
607607
resampler = Resampler(resampling=holdout, model=ridge_model, measure=mae,
608608
acceleration=accel)
609+
@test constructor(resampler) == Resampler
610+
@test package_name(resampler) == "MLJBase"
611+
@test load_path(resampler) == "MLJBase.Resampler"
609612
resampling_machine = machine(resampler, X, y)
610613
@test_logs((:info, r"^Training"), fit!(resampling_machine))
611614
e1=evaluate(resampling_machine).measurement[1]

0 commit comments

Comments
 (0)