Skip to content

Commit 2e3f347

Browse files
committed
.
1 parent 3ed6ff9 commit 2e3f347

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export LightInterface, FullInterface
1414
# MLJ model hierarchy
1515
export MLJType, Model, Supervised, Unsupervised,
1616
Probabilistic, Deterministic, Interval, Static,
17-
UnivariateFinite
17+
UnivariateFinite, UnivariateFiniteVector
1818

1919
# model constructor + metadata
2020
export @mlj_model, metadata_pkg, metadata_model

src/data_utils.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ _squeeze(v) = first(v)
292292

293293
const UNIVARIATE_FINITE_DOCSTRING =
294294
"""
295-
UnivariateFinite(classes, p; pool=nothing, ordered=false)
295+
UnivariateFinite(classes, p; pool=nothing, ordered=false)
296296
297297
Construct a discrete univariate distribution whose finite support is
298298
the elements of the vector `classes`, and whose corresponding
@@ -329,3 +329,17 @@ UnivariateFinite(c::AbstractVector, p; kwargs...) =
329329

330330
UnivariateFinite(::LightInterface, a...; kwargs...) =
331331
errlight("UnivariateFinite")
332+
333+
const UNIVARIATE_FINITE_VECTOR_DOCSTRING =
334+
"""
335+
UnivariateFiniteVector(scores, classes)
336+
337+
Container for UnivariateFinite elements optimised for efficiency.
338+
Accessing a single element will construct and return the corresponding
339+
UnivariateFinite lazily.
340+
"""
341+
UnivariateFiniteVector(s::AbstractArray, a...) =
342+
UnivariateFiniteVector(get_interface_mode(), s, a...)
343+
344+
UnivariateFiniteVector(::LightInterface, a...) =
345+
errlight("UnivariateFiniteVector")

test/data_utils.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
end
3535
@testset "int-full" begin
3636
setfull()
37-
M.int(::FI, x::CategoricalElement; kw...) =
37+
M.int(::FI, x::CategoricalValue; kw...) =
3838
CategoricalArrays.order(x.pool)[x.level]
3939
x = categorical(['a','b','a'])
4040
@test int(x[1]) == 0x01
@@ -53,7 +53,7 @@ end
5353
setfull()
5454
M.classes(::FI, p::CategoricalPool) =
5555
[p[i] for i in invperm(CategoricalArrays.order(p))]
56-
M.classes(::FI, x::CategoricalElement) = classes(x.pool)
56+
M.classes(::FI, x::CategoricalValue) = classes(x.pool)
5757
x = categorical(['a','b','a'])
5858
@test classes(x[1]) == ['a', 'b']
5959
end
@@ -230,4 +230,18 @@ end
230230
setlight()
231231
@test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))
232232
@test_throws M.InterfaceError UnivariateFinite(randn(2), randn(2))
233+
@test_throws M.InterfaceError UnivariateFiniteVector(randn(2), randn(2))
234+
235+
setfull()
236+
yc = categorical([1,2])
237+
c = classes(yc[1])
238+
s = rand()
239+
u = UnivariateFinite(c, [1-s, s])
240+
@test u isa MLJBase.UnivariateFinite
241+
@test MLJBase.pdf(u, c[2]) == s
242+
243+
s = rand(5)
244+
u = UnivariateFiniteVector(s)
245+
@test u isa MLJBase.UnivariateFiniteVector
246+
@test MLJBase.pdf.(u, u.classes[2]) == s
233247
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import MLJBase
66

77
const M = MLJModelInterface
88
const FI = M.FullInterface
9-
const CategoricalElement = Union{CategoricalValue,CategoricalString}
109
ScientificTypes.TRAIT_FUNCTION_GIVEN_NAME[:table] = Tables.istable
1110

1211
setlight() = M.set_interface_mode(M.LightInterface())

0 commit comments

Comments
 (0)