Skip to content

Commit da22c22

Browse files
committed
Remove Dirichlet 'fitting'
1 parent 93d6275 commit da22c22

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

src/parameters.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,17 @@ function _initialize(rng, t::Tuple{ParamRange, Any}, n)
3333
return _initialize(rng, t[1], t[2], n)
3434
end
3535

36-
# Initialize parameters with default distribution types
36+
# Initialize parameters with default distributions
3737

38-
_initialize(rng, r::ParamRange, n) = _initialize(rng, r, _initializer(r), n)
39-
40-
_initializer(::NominalRange) = Dirichlet
38+
function _initialize(rng, r::NominalRange{T, N}, n) where {T, N}
39+
d = Dirichlet(ones(N))
40+
return _initialize(rng, r, d, n)
41+
end
4142

42-
_initializer(r::NumericRange) = _initializer(MLJTuning.boundedness(r))
43+
function _initialize(rng, r::NumericRange, n)
44+
D = _initializer(MLJTuning.boundedness(r))
45+
return _initialize(rng, r, D, n)
46+
end
4347

4448
_initializer(::Type{MLJBase.Bounded}) = Uniform
4549

@@ -53,11 +57,6 @@ function _initialize(rng, r::ParamRange, D::Type{<:Distribution}, n)
5357
throw(ArgumentError("$D distribution is unsupported for $(typeof(r))."))
5458
end
5559

56-
function _initialize(rng, r::NominalRange{T, N}, D::Type{Dirichlet}, n) where {T, N}
57-
d = Dirichlet(ones(N))
58-
return _initialize(rng, r, d, n)
59-
end
60-
6160
function _initialize(rng, r::NumericRange, D::Type{<:UnivariateDistribution}, n)
6261
d = Distributions.fit(D, r)
6362
return _initialize(rng, r, d, n)

0 commit comments

Comments
 (0)