Skip to content

Commit 1bfe302

Browse files
authored
Merge pull request #122 from JuliaAI/dev
For a 1.3.3 release
2 parents 402ca25 + 5684782 commit 1bfe302

File tree

4 files changed

+38
-32
lines changed

4 files changed

+38
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.3.2"
4+
version = "1.3.3"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/data_utils.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,14 @@ construct an abstract *array* of `UnivariateFinite` distributions by
323323
choosing `probs` to be an array of one higher dimension than the array
324324
generated.
325325
326+
Here the word "probabilities" is an abuse of terminology as there is
327+
no requirement that probabilities actually sum to one, only that they
328+
be non-negative. So `UnivariateFinite` objects actually implement
329+
arbitrary non-negative measures over finite sets of labelled points. A
330+
`UnivariateDistribution` will be a bona fide probability measure when
331+
constructed using the `augment=true` option (see below) or when
332+
`fit` to data.
333+
326334
Unless `pool` is specified, `support` should have type
327335
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
328336
share the same categorical pool, which may be larger than `support`.
@@ -335,7 +343,8 @@ If `probs` is a matrix, it should have a column for each class in
335343
`support` (or one less, if `augment=true`). More generally, `probs`
336344
will be an array whose size is of the form `(n1, n2, ..., nk, c)`,
337345
where `c = length(support)` (or one less, if `augment=true`) and the
338-
constructor then returns an array of size `(n1, n2, ..., nk)`.
346+
constructor then returns an array of `UnivariateFinite` distributions
347+
of size `(n1, n2, ..., nk)`.
339348
340349
```
341350
using CategoricalArrays
@@ -401,11 +410,12 @@ julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
401410
402411
### Probability augmentation
403412
404-
Unless `augment=true`, sums of elements along the last axis (row-sums
405-
in the case of a matrix) must be equal to one, and otherwise such an
406-
array is created by inserting appropriate elements *ahead* of those
407-
provided. This means the provided probabilities are associated with
408-
the the classes `c2, c3, ..., cn`.
413+
If `augment=true` the provided array is augmented by inserting
414+
appropriate elements *ahead* of those provided, along the last
415+
dimension of the array. This means the user only provides probabilities
416+
for the classes `c2, c3, ..., cn`. The class `c1` probabilities are
417+
chosen so that each `UnivariateFinite` distribution in the returned
418+
array is a bona fide probability distribution.
409419
410420
---
411421

src/model_api.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
1515
# fallbacks for supervised models that don't support sample weights:
1616
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1717

18-
# fallback for unsupervised detectors when no "evaluation" labels appear:
19-
fit(m::Union{ProbabilisticUnsupervisedDetector,
20-
DeterministicUnsupervisedDetector},
21-
verbosity,
22-
X,
23-
y) = fit(m, verbosity, X)
18+
# fallback for unsupervised annotators when labels or weights appear:
19+
# this is useful for evaluation and mixed composite models that combine
20+
# both supervised and unsupervised annotators
21+
fit(m::UnsupervisedAnnotator, verbosity, X, y) = fit(m, verbosity, X)
22+
fit(m::UnsupervisedAnnotator, verbosity, X, y, w) = fit(m, verbosity, X)
2423

2524
"""
2625
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)
@@ -90,7 +89,7 @@ selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)
9089
# this operation can be optionally overloaded to provide access to
9190
# fitted parameters (eg, coeficients of linear model):
9291
"""
93-
fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
92+
fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
9493
9594
Models may overload `fitted_params`. The fallback returns
9695
`(fitresult=fitresult,)`.

src/model_traits.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ for M in ABSTRACT_MODEL_SUBTYPES
4141
@eval(StatTraits.abstract_type(::Type{<:$M}) = $M)
4242
end
4343

44-
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
45-
Tuple{input_scitype(M)}
46-
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
47-
function StatTraits.fit_data_scitype(M::Type{<:Supervised})
44+
# helper to determine the scitype of supervised models
45+
function supervised_fit_data_scitype(M)
4846
I = input_scitype(M)
4947
T = target_scitype(M)
5048
ret = Tuple{I,T}
@@ -57,21 +55,21 @@ function StatTraits.fit_data_scitype(M::Type{<:Supervised})
5755
end
5856
return ret
5957
end
60-
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
58+
59+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
6160
Tuple{input_scitype(M)}
61+
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
62+
StatTraits.fit_data_scitype(M::Type{<:Supervised}) =
63+
supervised_fit_data_scitype(M)
64+
65+
# In special case of `UnsupervisedAnnotator`, we allow the target
66+
# as an optional argument to `fit` (that is ignored) so that the
67+
# `machine` constructor will accept it as a valid argument, which
68+
# then enables *evaluation* of the detector with labeled data:
69+
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
70+
Union{Tuple{input_scitype(M)}, supervised_fit_data_scitype(M)}
6271
StatTraits.fit_data_scitype(M::Type{<:SupervisedAnnotator}) =
63-
Tuple{input_scitype(M),target_scitype(M)}
64-
65-
# In special case of `UnsupervisedProbabilisticDetector`, and
66-
# `UnsupervsedDeterministicDetector` we allow the target as an
67-
# optional argument to `fit` (that is ignored) so that the `machine`
68-
# constructor will accept it as a valid argument, which then enables
69-
# *evaluation* of the detector with labeled data:
70-
StatTraits.fit_data_scitype(M::Type{<:Union{
71-
ProbabilisticUnsupervisedDetector,
72-
DeterministicUnsupervisedDetector}}) =
73-
Union{Tuple{input_scitype(M)},
74-
Tuple{input_scitype(M),target_scitype(M)}}
72+
supervised_fit_data_scitype(M)
7573

7674
StatTraits.transform_scitype(M::Type{<:Unsupervised}) =
7775
output_scitype(M)
@@ -82,7 +80,6 @@ StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
8280
StatTraits.predict_scitype(M::Type{<:Union{
8381
Deterministic,DeterministicDetector}}) = target_scitype(M)
8482

85-
8683
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
8784
## `ProbabilisticDetector` MODELS
8885

0 commit comments

Comments
 (0)