|
#' k-fold cross-validation resampler class
#'
#' S4 class, extending \code{resampler}, that applies k-fold
#' cross-validation to a model or model.seq().
#'
#' The prototype sets \code{params.folds} to 10 and \code{params.method} to
#' 'venetian'. No prototype value is given here for
#' \code{params.factor_name} or \code{outputs.metric}.
#'
#' @export kfold_xval2
#' @examples
#' I = kfold_xval2()
kfold_xval2 <- setClass(
  Class = "kfold_xval2",
  contains = "resampler",
  slots = c(
    # number of folds
    params.folds = "numeric",
    # how samples are allocated to folds
    params.method = "character",
    params.factor_name = "entity",
    # results table filled in by the run method
    outputs.metric = "data.frame"
  ),
  prototype = list(
    name = "k-fold cross-validation",
    type = "resampling",
    result = "metric",
    params.folds = 10,
    params.method = "venetian"
  )
)
| 22 | + |
#' @export
#' @template run
setMethod(f = "run",
  signature = c("kfold_xval2", "dataset", "metric"),
  definition = function(I, D, MET = NULL) {
    # Runs k-fold cross-validation of the workflow returned by models(I) on
    # dataset D. Each fold trains on the samples outside the fold, predicts
    # both the training and the held-out samples, and the collected
    # predictions are scored with MET. The result is a one-row data.frame
    # (training_set, test_set, metric) stored in I$metric; I is returned.

    X <- dataset.data(D)
    # NOTE(review): the original never assigned Y before using it. The
    # accessor name below is assumed by analogy with dataset.data() and the
    # `sample_meta` argument of dataset() -- confirm against the package.
    Y <- dataset.sample_meta(D)

    WF <- models(I)

    # hoist the loop-invariant parameter look-ups
    n_folds <- param.value(I, "folds")
    cv_method <- param.value(I, "method")
    n_samples <- nrow(X)

    # 1:n_folds silently yields c(1, 0) for n_folds == 0, and more folds than
    # samples would produce empty test sets, so reject both up front
    if (length(n_folds) != 1 || is.na(n_folds) || n_folds < 1) {
      stop("the number of folds must be a single number >= 1")
    }
    if (n_folds > n_samples) {
      stop("the number of folds cannot exceed the number of samples")
    }

    # assign every sample to a fold
    fold_id <- rep(seq_len(n_folds), length.out = n_samples)
    fold_id <- switch(cv_method,
      venetian = fold_id,                                        # 123123123123
      blocks = sort(fold_id),                                    # 111122223333
      random = sample(fold_id, length(fold_id), replace = FALSE),
      stop('unknown method for cross-validation. (try "venetian", "blocks" or "random")')
    )

    # HAS TO BE A model OR model.seq; fail before any training is attempted
    # instead of silently skipping every fold
    if (!is(WF, "model_OR_model.seq")) {
      if (is(WF, "iterator")) {
        stop("not implemented yet")
      }
      stop("kfold_xval2 can only be run on a model or model.seq")
    }

    # pick the part of a prediction that the metric compares against
    extract_prediction <- function(p) {
      if (MET@actual == "sample_meta") {
        p
      } else if (MET@actual == "data") {
        p$data
      } else {
        stop("MET$actual not implemented yet")
      }
    }

    # build an all-NA container with one row per sample and the same columns
    # (and column types) as a fold's prediction; assumes predictions are
    # two-dimensional (matrix or data.frame), as the row-indexed assignments
    # below require
    allocate_like <- function(yhat) {
      out <- yhat[rep(NA_integer_, n_samples), , drop = FALSE]
      rownames(out) <- rownames(X)
      out
    }

    YHAT <- NULL    # held-out (test) predictions, one row per sample
    YHATtr <- NULL  # training predictions, overwritten by each later fold

    # for each value of k, split the data and run the workflow
    for (i in seq_len(n_folds)) {
      in_test <- fold_id == i

      dtrain <- dataset(
        data = X[!in_test, , drop = FALSE],
        sample_meta = Y[!in_test, , drop = FALSE]
      )
      dtest <- dataset(
        data = X[in_test, , drop = FALSE],
        sample_meta = Y[in_test, , drop = FALSE]
      )

      # train, then predict the training samples
      WF <- model.train(WF, dtrain)
      WF <- model.predict(WF, dtrain)
      yhat <- extract_prediction(predicted(WF))
      if (is.null(YHATtr)) {
        YHATtr <- allocate_like(yhat)
      }
      YHATtr[!in_test, ] <- yhat

      # predict the held-out samples
      WF <- model.predict(WF, dtest)
      yhat <- extract_prediction(predicted(WF))
      if (is.null(YHAT)) {
        YHAT <- allocate_like(yhat)
      }
      YHAT[in_test, ] <- yhat
    }

    if (MET@actual == "data") {
      # if its a model sequence get the prediction from the penultimate step
      # for comparison with the predictions
      if (is(WF, "model_OR_model.seq")) {
        # apply model to data
        WF <- model.apply(WF, D)
        n <- length(WF)
        if (n > 1) { # just in case a sequence of 1
          Y <- predicted(WF[n - 1])$data
        }
      }
    }

    df <- data.frame("training_set" = 0, "test_set" = 0, "metric" = class(MET)[[1]])
    # training set metric (the original stored these two the wrong way round)
    MET <- calculate(MET, Y, YHATtr)
    df$training_set <- value(MET)
    # test set metric
    MET <- calculate(MET, Y, YHAT)
    df$test_set <- value(MET)

    I$metric <- df
    return(I)
  }
)
| 128 | + |
0 commit comments