Skip to content

Commit 076cbd2

Browse files
authored
Merge pull request #13 from JuliaAI/untether-distributions
Free UnivariateFinite from being a subtype of Distributions.Distribution
2 parents c785a8a + ff41c7b commit 076cbd2

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

src/methods.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,9 @@ function Dist.fit(d::Type{<:UnivariateFinite},
371371
end
372372

373373

374+
# # BROADCASTING OVER SINGLE UNIVARIATE FINITE
375+
376+
# This mirrors behaviour assigned Distributions.Distribution objects,
377+
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.
378+
379+
Broadcast.broadcastable(d::UnivariateFinite) = Ref(d)

src/types.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,14 @@ same size as the array.
129129

130130
# # TYPES - PLAIN AND ARRAY
131131

132-
# extend Ditributions type hiearchy to account for non-euclidean
133-
# supports:
134-
abstract type Categorical{S<:Finite} <: Dist.ValueSupport end
135-
136-
# not exported:
137-
const _UnivariateFinite_{S} =
138-
Dist.Distribution{Dist.Univariate,Categorical{S}}
139-
140132
# R - reference type <: Unsigned
141133
# V - type of class labels (eg, Char in `categorical(['a', 'b'])`)
142134
# P - raw probability type
143135
# S - scitype of samples
144136

145137
# Note that the keys of `prob_given_ref` need not exhaust all the
146138
# refs of all classes but will be ordered (LittleDicts preserve order)
147-
struct UnivariateFinite{S,V,R,P} <: _UnivariateFinite_{S}
139+
struct UnivariateFinite{S,V,R,P}
148140
scitype::Type{S}
149141
decoder::CategoricalDecoder{V,R}
150142
prob_given_ref::LittleDict{R,P,Vector{R}, Vector{P}}

test/methods.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
121121
@test pdf(d, "class_1") == 0.7
122122
end
123123

124+
@testset "broadcasting pdf over single UnivariateFinite object" begin
125+
d = UnivariateFinite(["a", "b"], [0.1, 0.9], pool=missing);
126+
@test pdf.(d, ["a", "b"]) == [0.1, 0.9]
127+
end
128+
124129
@testset "constructor arguments not categorical values" begin
125130
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2))
126131
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2),
@@ -277,7 +282,6 @@ end
277282
# @test v ≈ v_close
278283
end
279284

280-
281285
end # module
282286

283287
true

0 commit comments

Comments
 (0)