@@ -130,14 +130,32 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
130130 if (! is.null(self $ input_trafo )) {
131131 xdt = self $ input_trafo $ transform(xdt )
132132 }
133+
134+ # speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
133135 preds = lapply(self $ learner , function (learner ) {
134- pred = learner $ predict_newdata(newdata = xdt )
136+ task = learner $ state $ train_task $ clone()
137+ set(xdt , j = task $ target_names , value = NA_real_ ) # tasks only have features and the target but we have to set the target to NA
138+ newdata = as_data_backend(xdt )
139+ task $ backend = newdata
140+ task $ row_roles $ use = task $ backend $ rownames
141+ pred = learner $ predict(task )
135142 if (learner $ predict_type == " se" ) {
136143 data.table(mean = pred $ response , se = pred $ se )
137144 } else {
138145 data.table(mean = pred $ response )
139146 }
140147 })
148+
149+ # slow
150+ # preds = lapply(self$learner, function(learner) {
151+ # pred = learner$predict_newdata(newdata = xdt)
152+ # if (learner$predict_type == "se") {
153+ # data.table(mean = pred$response, se = pred$se)
154+ # } else {
155+ # data.table(mean = pred$response)
156+ # }
157+ # })
158+
141159 names(preds ) = names(self $ learner )
142160 if (! is.null(self $ output_trafo ) && self $ output_trafo $ invert_posterior ) {
143161 preds = self $ output_trafo $ inverse_transform_posterior(preds )
0 commit comments