File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change 11ignore :
22 - " **/utils*"
3+ - " tests/*"
4+
35coverage :
46 status :
57 patch : false
Original file line number Diff line number Diff line change 22import numpy as np
33import harmonic as hm
44import getdist
5- from harmonic import model_legacy
6- from getdist import plots
75import matplotlib as plt
86
97
@@ -263,7 +261,7 @@ def cross_validation(
263261 domains : List ,
264262 hyper_parameters : List ,
265263 nfold = 2 ,
266- modelClass = model_legacy . KernelDensityEstimate ,
264+ modelClass = None ,
267265 seed : int = - 1 ,
268266) -> List :
269267 """Perform n-fold validation for given model using chains to be split into
@@ -285,8 +283,8 @@ def cross_validation(
285283 hyper_parameters (List): List of hyper_parameters where each entry is a
286284 hyper_parameter list to be considered.
287285
288- modelClass (Model): Model that is being cross validated (default =
289- KernelDensityEstimate).
286+ modelClass (Model): Model that is being cross validated (defaults to
287+ KernelDensityEstimate inside function ).
290288
291289 seed (int): Seed for random number generator when drawing the chains
292290 (if this is negative the seed is not set).
@@ -301,6 +299,9 @@ def cross_validation(
301299
302300 """
303301
302+ if modelClass is None :
303+ modelClass = hm .model_legacy .KernelDensityEstimate
304+
304305 ln_validation_variances = np .zeros ((nfold , len (hyper_parameters )))
305306
306307 if seed > 0 :
You can’t perform that action at this time.
0 commit comments