|
| 1 | +struct CombinatorialCrossValidation{T1, T2, T3, T4} <: CrossValidationEstimator |
| 2 | + n_folds::T1 |
| 3 | + n_test_folds::T2 |
| 4 | + purged_size::T3 |
| 5 | + embargo_size::T4 |
| 6 | + function CombinatorialCrossValidation(n_folds::Integer, n_test_folds::Integer, |
| 7 | + purged_size::Integer, embargo_size::Integer) |
| 8 | + assert_nonempty_gt0_finite_val(n_folds, :n_folds) |
| 9 | + assert_nonempty_gt0_finite_val(n_test_folds, :n_test_folds) |
| 10 | + assert_nonempty_finite_val(purged_size, :purged_size) |
| 11 | + assert_nonempty_finite_val(embargo_size, :embargo_size) |
| 12 | + return new{typeof(n_folds), typeof(n_test_folds), typeof(purged_size), |
| 13 | + typeof(embargo_size)}(n_folds, n_test_folds, purged_size, embargo_size) |
| 14 | + end |
| 15 | +end |
| 16 | +function CombinatorialCrossValidation(; n_folds::Integer = 10, n_test_folds::Integer = 8, |
| 17 | + purged_size::Integer = 0, embargo_size::Integer = 0) |
| 18 | + return CombinatorialCrossValidation(n_folds, n_test_folds, purged_size, embargo_size) |
| 19 | +end |
| 20 | +function Base.split(ccv::CombinatorialCrossValidation, rd::ReturnsResult) |
| 21 | + #= |
| 22 | + T = size(rd.X, 1) |
| 23 | + (; n_folds, n_test_folds, purged_size, embargo_size) = ccv |
| 24 | + idx = 1:T |
| 25 | + min_fold_size = div(T, n_folds) |
| 26 | + @argcheck(purged_size + embargo_size < min_fold_size) |
| 27 | + fold_sizes = fill(min_fold_size, n_folds) |
| 28 | + fold_sizes[1:(mod(T, n_folds))] .+= one(eltype(fold_sizes)) |
| 29 | + fold_indices = Vector{typeof(idx)}(undef, 0) |
| 30 | + current = one(eltype(fold_sizes)) |
| 31 | + for fold_size in fold_sizes |
| 32 | + start, stop = current, current + fold_size |
| 33 | + push!(fold_indices, idx[start:(stop - 1)]) |
| 34 | + current = stop |
| 35 | + end |
| 36 | + test_indices = Vector{typeof(idx)}(undef, 0) |
| 37 | + for test_fold in combinations(1:n_folds, n_test_folds) |
| 38 | + push!(test_indices, vcat(fold_indices[test_fold]...)) |
| 39 | + end |
| 40 | + train_indices = Vector{Vector{eltype(T)}}(undef, 0) |
| 41 | + for test_fold in combinations(1:n_folds, n_test_folds) |
| 42 | + tmp_test_idx = Vector{typeof(idx)}(undef, 0) |
| 43 | + for j in test_fold |
| 44 | + if j == minimum(test_fold) - 1 |
| 45 | + push!(tmp_test_idx, fold_indices[j][1:(end - purged_size)]) |
| 46 | + elseif j == maximum(test_fold) + 1 |
| 47 | + push!(tmp_test_idx, |
| 48 | + fold_indices[j][(1 + purged_size + embargo_size):end]) |
| 49 | + else |
| 50 | + push!(tmp_test_idx, fold_indices[j]) |
| 51 | + end |
| 52 | + end |
| 53 | + push!(train_indices, |
| 54 | + setdiff(idx, vcat(tmp_test_idx..., fold_indices[test_fold]...))) |
| 55 | + end |
| 56 | + return train_indices, test_indices |
| 57 | + =# |
| 58 | +end |
| 59 | +function n_splits(ccv::CombinatorialCrossValidation) |
| 60 | + return binomial(ccv.n_folds, ccv.n_test_folds) |
| 61 | +end |
| 62 | +function n_test_paths(ccv::CombinatorialCrossValidation) |
| 63 | + return div(n_splits(ccv) * ccv.n_test_folds, ccv.n_folds) |
| 64 | +end |
| 65 | +function average_train_size(ccv::CombinatorialCrossValidation, rd::ReturnsResult) |
| 66 | + T = size(rd.X, 1) |
| 67 | + (; n_folds, n_test_folds) = ccv |
| 68 | + return T / n_folds * (n_folds - n_test_folds) |
| 69 | +end |
| 70 | +function test_set_index(ccv::CombinatorialCrossValidation) |
| 71 | + return collect(Combinatorics.combinations(1:(ccv.n_folds), ccv.n_test_folds)) |
| 72 | +end |
| 73 | +function binary_train_test_sets(ccv::CombinatorialCrossValidation) |
| 74 | + n_folds = ccv.n_folds |
| 75 | + num_splits = n_splits(ccv) |
| 76 | + type = promote_type(typeof(num_splits), typeof(n_folds)) |
| 77 | + folds_train_test = zeros(type, n_folds, num_splits) |
| 78 | + test_set_idx = test_set_index(ccv) |
| 79 | + for (i, idx) in enumerate(test_set_idx) |
| 80 | + folds_train_test[idx, i] .= one(type) |
| 81 | + end |
| 82 | + return folds_train_test |
| 83 | +end |
| 84 | + |
| 85 | +export CombinatorialCrossValidation, n_test_paths, average_train_size, test_set_index, |
| 86 | + binary_train_test_sets |
0 commit comments