Skip to content

Commit 3f66616

Browse files
authored
Merge pull request #45 from alan-turing-institute/ufvector
UnivariateFiniteVector
2 parents 3ed6ff9 + 6000d8c commit 3f66616

File tree

4 files changed

+24
-8
lines changed

4 files changed

+24
-8
lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export LightInterface, FullInterface
1414
# MLJ model hierarchy
1515
export MLJType, Model, Supervised, Unsupervised,
1616
Probabilistic, Deterministic, Interval, Static,
17-
UnivariateFinite
17+
UnivariateFinite, UnivariateFiniteVector
1818

1919
# model constructor + metadata
2020
export @mlj_model, metadata_pkg, metadata_model

src/data_utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ _squeeze(v) = first(v)
292292

293293
const UNIVARIATE_FINITE_DOCSTRING =
294294
"""
295-
UnivariateFinite(classes, p; pool=nothing, ordered=false)
295+
UnivariateFinite(classes, p; pool=nothing, ordered=false)
296296
297297
Construct a discrete univariate distribution whose finite support is
298298
the elements of the vector `classes`, and whose corresponding
@@ -329,3 +329,19 @@ UnivariateFinite(c::AbstractVector, p; kwargs...) =
329329

330330
UnivariateFinite(::LightInterface, a...; kwargs...) =
331331
errlight("UnivariateFinite")
332+
333+
const UNIVARIATE_FINITE_VECTOR_DOCSTRING =
334+
"""
335+
UnivariateFiniteVector(scores, classes)
336+
337+
Container for UnivariateFinite elements optimised for efficiency.
338+
Accessing a single element will construct and return the corresponding
339+
UnivariateFinite lazily.
340+
"""
341+
UnivariateFiniteVector(s::AbstractArray) =
342+
UnivariateFiniteVector(get_interface_mode(), s)
343+
UnivariateFiniteVector(s::AbstractArray, c) =
344+
UnivariateFiniteVector(get_interface_mode(), s, c)
345+
346+
UnivariateFiniteVector(::LightInterface, a...) =
347+
errlight("UnivariateFiniteVector")

test/data_utils.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ end
3434
end
3535
@testset "int-full" begin
3636
setfull()
37-
M.int(::FI, x::CategoricalElement; kw...) =
38-
CategoricalArrays.order(x.pool)[x.level]
37+
M.int(::FI, x::CategoricalValue; kw...) =
38+
collect(1:length(levels(x.pool)))[x.level]
3939
x = categorical(['a','b','a'])
4040
@test int(x[1]) == 0x01
4141
@test int(x[2]) == 0x02
42-
@test int(x[2]) isa UInt32
42+
@test_broken int(x[2]) isa UInt32
4343
@test int(x[1], type=Int64) == 1
4444
@test int(x[1], type=Int64) isa Int64
4545
end
@@ -52,8 +52,8 @@ end
5252
@testset "classes-full" begin
5353
setfull()
5454
M.classes(::FI, p::CategoricalPool) =
55-
[p[i] for i in invperm(CategoricalArrays.order(p))]
56-
M.classes(::FI, x::CategoricalElement) = classes(x.pool)
55+
[p[i] for i in invperm(1:length(levels(p)))]
56+
M.classes(::FI, x::CategoricalValue) = classes(x.pool)
5757
x = categorical(['a','b','a'])
5858
@test classes(x[1]) == ['a', 'b']
5959
end
@@ -230,4 +230,5 @@ end
230230
setlight()
231231
@test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))
232232
@test_throws M.InterfaceError UnivariateFinite(randn(2), randn(2))
233+
@test_throws M.InterfaceError UnivariateFiniteVector(randn(2), randn(2))
233234
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import MLJBase
66

77
const M = MLJModelInterface
88
const FI = M.FullInterface
9-
const CategoricalElement = Union{CategoricalValue,CategoricalString}
109
ScientificTypes.TRAIT_FUNCTION_GIVEN_NAME[:table] = Tables.istable
1110

1211
setlight() = M.set_interface_mode(M.LightInterface())

0 commit comments

Comments
 (0)