1- struct CombinatorialCrossValidation{T1, T2, T3, T4} <: CrossValidationEstimator
1+ struct CombinatorialCrossValidation{T1, T2, T3, T4} <: NonSequentialCrossValidationEstimator
22 n_folds:: T1
33 n_test_folds:: T2
44 purged_size:: T3
@@ -20,16 +20,25 @@ function CombinatorialCrossValidation(; n_folds::Integer = 10, n_test_folds::Int
2020 purged_size:: Integer = 0 , embargo_size:: Integer = 0 )
2121 return CombinatorialCrossValidation (n_folds, n_test_folds, purged_size, embargo_size)
2222end
23+ function n_splits (n_folds:: Integer , n_test_folds:: Integer )
24+ return binomial (n_folds, n_test_folds)
25+ end
2326function n_splits (ccv:: CombinatorialCrossValidation )
24- return binomial (ccv. n_folds, ccv. n_test_folds)
27+ return n_splits (ccv. n_folds, ccv. n_test_folds)
28+ end
29+ function n_test_paths (n_folds:: Integer , n_test_folds:: Integer )
30+ return div (n_splits (n_folds, n_test_folds) * n_test_folds, n_folds)
2531end
2632function n_test_paths (ccv:: CombinatorialCrossValidation )
2733 return div (n_splits (ccv) * ccv. n_test_folds, ccv. n_folds)
2834end
35+ function average_train_size (T:: Integer , n_folds:: Integer , n_test_folds:: Integer )
36+ return T / n_folds * (n_folds - n_test_folds)
37+ end
2938function average_train_size (ccv:: CombinatorialCrossValidation , rd:: ReturnsResult )
3039 T = size (rd. X, 1 )
3140 (; n_folds, n_test_folds) = ccv
32- return T / n_folds * (n_folds - n_test_folds)
41+ return average_train_size (T, n_folds, n_test_folds)
3342end
3443function test_set_index (ccv:: CombinatorialCrossValidation )
3544 return collect (Combinatorics. combinations (1 : (ccv. n_folds), ccv. n_test_folds))
@@ -67,10 +76,9 @@ function get_path_ids(ccv::CombinatorialCrossValidation)
6776end
6877function Base. split (ccv:: CombinatorialCrossValidation , rd:: ReturnsResult )
6978 T = size (rd. X, 1 )
70- (; n_folds, n_test_folds, purged_size, embargo_size) = ccv
79+ (; n_folds, purged_size, embargo_size) = ccv
7180 min_fold_size = div (T, n_folds)
72- pes = purged_size + embargo_size
73- @argcheck (pes < min_fold_size)
81+ @argcheck (purged_size + embargo_size < min_fold_size)
7482 fold_idx_num = div .(0 : (T - 1 ), min_fold_size)
7583 fold_idx_num[fold_idx_num .== n_folds] .= n_folds - 1
7684 fold_idx_num .+ = 1
@@ -94,21 +102,49 @@ function Base.split(ccv::CombinatorialCrossValidation, rd::ReturnsResult)
94102 after_idx = findall (x -> x == - 1 , dif)
95103 after_idx_1 = getindex .(getindex .(after_idx, 1 ))
96104 after_idx_2 = getindex .(getindex .(after_idx, 2 ))
97- for i in 1 : pes
105+ for i in 1 : (purged_size + embargo_size)
98106 j = map (x -> min (T, x + i), after_idx_1)
99107 for (j, k) in zip (j, after_idx_2)
100108 train_test_idx[j, k] = - one (num_splits)
101109 end
102110 end
103111 fold_index = Dict (i => findall (fold_idx_num .== i) for i in 1 : n_folds)
104- # ! allocate train and test induces
112+ train_idx = Vector {Vector{typeof(T)}} (undef, 0 )
113+ test_idx_list = Vector {Vector{Vector{typeof(T)}}} (undef, 0 )
105114 for i in 1 : num_splits
106- train_idx = findall (x -> x == 0 , view (train_test_idx, :, i))
107- test_idx_list = [fold_index[j[1 ]] for j in findall (x -> x == i, rcp)]
108- return train_idx, test_idx_list
115+ push! (train_idx, findall (x -> x == 0 , view (train_test_idx, :, i)))
116+ push! (test_idx_list, [fold_index[j[1 ]] for j in findall (x -> x == i, rcp)])
117+ end
118+ return train_idx, test_idx_list
119+ end
120+ function optimal_number_folds (T:: Integer , target_train_size:: Integer ,
121+ target_n_test_paths:: Integer ; train_size_w:: Number = 1 ,
122+ n_test_paths_w:: Number = 1 , maxval:: Number = 1e5 )
123+ function _cost (x:: Integer , y:: Integer )
124+ return n_test_paths_w * abs (n_test_paths (x, y) - target_n_test_paths) /
125+ target_n_test_paths +
126+ train_size_w * abs (average_train_size (T, x, y) - target_train_size) /
127+ target_train_size
128+ end
129+ costs = Vector{promote_type (typeof (train_size_w), typeof (n_test_paths_w),
130+ typeof (maxval))}(undef, 0 )
131+ type = promote_type (typeof (T), typeof (target_train_size), typeof (target_n_test_paths))
132+ res = Vector {Tuple{type, type}} (undef, 0 )
133+ for n_folds in 3 : (T + 1 )
134+ i = nothing
135+ for n_test_folds in 2 : n_folds
136+ if ! (isnothing (i) || n_folds - n_test_folds <= i)
137+ continue
138+ end
139+ cost = _cost (n_folds, n_test_folds)
140+ push! (costs, cost)
141+ push! (res, (n_folds, n_test_folds))
142+ if isnothing (i) && cost > maxval
143+ i = n_test_folds
144+ end
145+ end
109146 end
110- return fold_index
147+ return res[ argmin (costs)]
111148end
112149
113- export CombinatorialCrossValidation, n_test_paths, average_train_size, test_set_index,
114- binary_train_test_sets, recombined_paths, get_path_ids
150+ export CombinatorialCrossValidation, optimal_number_folds
0 commit comments