diff --git a/R/updateParVals.R b/R/updateParVals.R index 077c04c5..b6dab6f2 100644 --- a/R/updateParVals.R +++ b/R/updateParVals.R @@ -18,27 +18,52 @@ updateParVals = function(par.set, old.par.vals, new.par.vals, warn = FALSE) { assertList(old.par.vals, names = "named") assertList(new.par.vals, names = "named") assertClass(par.set, "ParamSet") - - #we might want to check requires with defaults that are not overwritten by new.par.vals - usable.defaults = getDefaults(par.set) - usable.defaults = usable.defaults[names(usable.defaults) %nin% names(new.par.vals)] + assertFlag(warn) + default.par.vals = getDefaults(par.set) + # First we extend both par.vals lists with the defaults to get the fully requirements meeting par.vals lists + old.with.defaults = updateParVals2(par.set = par.set, old.par.vals = default.par.vals, new.par.vals = old.par.vals) + updated.old = attr(old.with.defaults, "updated") + new.with.defaults = updateParVals2(par.set = par.set, old.par.vals = default.par.vals, new.par.vals = new.par.vals) + updated.new = attr(new.with.defaults, "updated") + # new.candidates are the par.vals we want to use from new.with.defaults. + # These exclude those values where the defaults got updated by the old.par.vals but the update is not present in the new.par.vals but still we want to keep this update. + # + # updated.old | updated.new | use + # T | T | new + # T | F | old + # F | T | new + # F | F | new(default) + new.candidates = names(new.with.defaults) %nin% setdiff(names(updated.old)[updated.old], names(updated.new)[updated.new]) + updated.par.vals = updateParVals2(par.set = par.set, old.par.vals = old.with.defaults, new.par.vals = new.with.defaults[new.candidates]) + # Find out which parmam names were kept in both update processes + # this indicates that this was a default and we don't need it, as it is still a valid default. + both.updated = union(names(updated.new)[updated.new], names(updated.old)[updated.old]) + result = updated.par.vals[names(updated.par.vals) %in% both.updated] + # order as in par.set + # result = result[match(names(result), getParamIds(par.set))] + if (warn) { + # detect dropped param settings: + warningf("ParamSettings (%s) were dropped.", convertToShortString(old.par.vals[names(old.par.vals) %nin% names(result)])) + } + return(result) +} + +updateParVals2 = function(par.set, old.par.vals, new.par.vals) { + updated = setNames(rep(TRUE, length(new.par.vals)), names(new.par.vals)) repeat { - # we repeat to include parameters which depend on each other by requirements - # candidates are params of the old par.vals we might still need. + # we include parameters of the old.par.vals if they meet the requirements + # we repeat because some parameters of old.par.vals might only meet the requirements after we added others. (chained requirements) candidate.par.names = setdiff(names(old.par.vals), names(new.par.vals)) for (pn in candidate.par.names) { # If all requirement parameters for the candidate are in the new.par.vals and if the requirements are met - if (all(getRequiredParamNames(par.set$pars[[pn]]) %in% names(new.par.vals)) && requiresOk(par.set$pars[[pn]], new.par.vals)) { - new.par.vals[pn] = old.par.vals[pn] # keep old.par.val in new.par.vals as it meets the requirements - } else if (all(getRequiredParamNames(par.set$pars[[pn]]) %in% names(usable.defaults)) && requiresOk(par.set$pars[[pn]], usable.defaults)) { - new.par.vals[pn] = old.par.vals[pn] # keep old.par.val as it meets requirement via defaults - } else if (warn){ - # otherwise we can drop the old par.val because it does not meet the requirements. - warningf("ParamSetting %s was dropped.", convertToShortString(old.par.vals[pn])) + if (isTRUE(try(requiresOk(par.set$pars[[pn]], c(new.par.vals, old.par.vals[pn])), silent = TRUE))) { + # keep old.par.val in new.par.vals as it meets the requirements + new.par.vals[pn] = old.par.vals[pn] + updated[pn] = FALSE } } # break if no changes were made if (identical(candidate.par.names, setdiff(names(old.par.vals), names(new.par.vals)))) break } - return(new.par.vals) -} + setAttribute(new.par.vals, "updated", updated) +} \ No newline at end of file diff --git a/tests/testthat/test_updateParVals.R b/tests/testthat/test_updateParVals.R index 8c50341c..041a5088 100644 --- a/tests/testthat/test_updateParVals.R +++ b/tests/testthat/test_updateParVals.R @@ -13,7 +13,11 @@ test_that("updateParVals works", { makeLogicalParam("f", default = TRUE)) pc = updateParVals(ps, pa, pb) expect_equal(pc, list(a = 0, c = 3, d = 4, e = 5)) - expect_warning(updateParVals(ps, pa, pb, warn = TRUE), "ParamSetting b=2") + expect_warning(updateParVals(ps, pa, pb, warn = TRUE), "ParamSettings \\(b=2\\)") + + pb2 = list(a = 0, f = FALSE) + pc2 = updateParVals(ps, pa, pb2) + expect_equal(pc2, list(a = 0, f = FALSE, d = 4)) pb2 = list(a = 0, f = FALSE) pc2 = updateParVals(ps, pa, pb2) @@ -44,4 +48,28 @@ test_that("updateParVals works", { pb = list(b = 3) pc = updateParVals(ps, pa, pb) expect_equal(pc, list(b = 3, c = TRUE)) + + #more complicated stuff + ps = makeParamSet( + makeDiscreteLearnerParam(id = "a", default = "a2", + values = c("a1", "a2", "a3"), + requires = quote(!a %in% c("a2") || b == TRUE)), + makeLogicalLearnerParam(id = "b", default = FALSE, tunable = FALSE) + ) + pa = list(a = "a1") + pb = list() + pc = updateParVals(ps, pa, pb) + expect_equal(pc, pa) + + ps = makeParamSet( + makeIntegerParam("a", default = 10L) + ) + pa = list(a = 0L) + pb = list() + pc = updateParVals(ps, pa, pb) + expect_equal(pc, pa) + pb2 = list(a = 5L) + pc2 = updateParVals(ps, pa, pb2) + expect_equal(pc2, pb2) + })