@@ -237,30 +237,15 @@ julia> res_input = minimal_init(8, 3; p=0.8)# higher p -> more positive signs
237237```
238238"""
239239function minimal_init (rng:: AbstractRNG , :: Type{T} , dims:: Integer... ;
240- sampling_type:: Symbol = :bernoulli , weight:: Number = T (0.1 ), irrational:: Real = pi ,
241- start:: Int = 1 , p:: Number = T (0.5 )) where {T <: Number }
240+ sampling_type:: Symbol = :bernoulli , kwargs... ) where {T <: Number }
242241 res_size, in_size = dims
243- if sampling_type == :bernoulli
244- layer_matrix = _create_bernoulli (p, res_size, in_size, weight, rng, T)
245- elseif sampling_type == :irrational
246- layer_matrix = _create_irrational (irrational,
247- start,
248- res_size,
249- in_size,
250- weight,
251- rng,
252- T)
253- else
254- error (""" \n
255- Sampling type not allowed.
256- Please use one of :bernoulli or :irrational\n
257- """ )
258- end
242+ f_sample = getfield (@__MODULE__ , sampling_type)
243+ layer_matrix = f_sample (rng, T, res_size, in_size; kwargs... )
259244 return layer_matrix
260245end
261246
262- function _create_bernoulli (p :: Number , res_size :: Int , in_size :: Int , weight :: Number ,
263- rng :: AbstractRNG , :: Type{T} ) where {T <: Number }
247+ function bernoulli (rng :: AbstractRNG , :: Type{T} , res_size :: Int , in_size :: Int ;
248+ weight :: Number = T ( 0.1 ), p :: Number = T ( 0.5 ) ) where {T <: Number }
264249 input_matrix = DeviceAgnostic. zeros (rng, T, res_size, in_size)
265250 for i in 1 : res_size
266251 for j in 1 : in_size
@@ -274,9 +259,9 @@ function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Numbe
274259 return input_matrix
275260end
276261
277- function _create_irrational ( irrational:: Irrational , start :: Int , res_size:: Int ,
278- in_size :: Int , weight :: Number , rng :: AbstractRNG ,
279- :: Type{T} ) where {T <: Number }
262+ function irrational (rng :: AbstractRNG , :: Type{T} , res_size:: Int , in_size :: Int ;
263+ irrational :: Irrational = pi , start :: Int = 1 ,
264+ weight :: Number = T ( 0.1 ) ) where {T <: Number }
280265 setprecision (BigFloat, Int (ceil (log2 (10 ) * (res_size * in_size + start + 1 ))))
281266 ir_string = string (BigFloat (irrational)) |> collect
282267 deleteat! (ir_string, findall (x -> x == ' .' , ir_string))
0 commit comments