Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions R/updateParVals.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,52 @@ updateParVals = function(par.set, old.par.vals, new.par.vals, warn = FALSE) {
assertList(old.par.vals, names = "named")
assertList(new.par.vals, names = "named")
assertClass(par.set, "ParamSet")

#we might want to check requires with defaults that are not overwritten by new.par.vals
usable.defaults = getDefaults(par.set)
usable.defaults = usable.defaults[names(usable.defaults) %nin% names(new.par.vals)]
assertFlag(warn)
default.par.vals = getDefaults(par.set)
# First we extend both par.vals lists with the defaults to get the fully requirements meeting par.vals lists
old.with.defaults = updateParVals2(par.set = par.set, old.par.vals = default.par.vals, new.par.vals = old.par.vals)
updated.old = attr(old.with.defaults, "updated")
new.with.defaults = updateParVals2(par.set = par.set, old.par.vals = default.par.vals, new.par.vals = new.par.vals)
updated.new = attr(new.with.defaults, "updated")
# new.candidates are the par.vals we want to use from new.with.defaults.
# These exclude those values where the defaults got updated by the old.par.vals but the update is not present in the new.par.vals but still we want to keep this update.
#
# updated.old | updated.new | use
# T | T | new
# T | F | old
# F | T | new
# F | F | new(default)
new.candidates = names(new.with.defaults) %nin% setdiff(names(updated.old)[updated.old], names(updated.new)[updated.new])
updated.par.vals = updateParVals2(par.set = par.set, old.par.vals = old.with.defaults, new.par.vals = new.with.defaults[new.candidates])
# Find out which parmam names were kept in both update processes
# this indicates that this was a default and we don't need it, as it is still a valid default.
both.updated = union(names(updated.new)[updated.new], names(updated.old)[updated.old])
result = updated.par.vals[names(updated.par.vals) %in% both.updated]
# order as in par.set
# result = result[match(names(result), getParamIds(par.set))]
if (warn) {
# detect dropped param settings:
warningf("ParamSettings (%s) were dropped.", convertToShortString(old.par.vals[names(old.par.vals) %nin% names(result)]))
}
return(result)
}

updateParVals2 = function(par.set, old.par.vals, new.par.vals) {
updated = setNames(rep(TRUE, length(new.par.vals)), names(new.par.vals))
repeat {
# we repeat to include parameters which depend on each other by requirements
# candidates are params of the old par.vals we might still need.
# we include parameters of the old.par.vals if they meet the requirements
# we repeat because some parameters of old.par.vals might only meet the requirements after we added others. (chained requirements)
candidate.par.names = setdiff(names(old.par.vals), names(new.par.vals))
for (pn in candidate.par.names) {
# If all requirement parameters for the candidate are in the new.par.vals and if the requirements are met
if (all(getRequiredParamNames(par.set$pars[[pn]]) %in% names(new.par.vals)) && requiresOk(par.set$pars[[pn]], new.par.vals)) {
new.par.vals[pn] = old.par.vals[pn] # keep old.par.val in new.par.vals as it meets the requirements
} else if (all(getRequiredParamNames(par.set$pars[[pn]]) %in% names(usable.defaults)) && requiresOk(par.set$pars[[pn]], usable.defaults)) {
new.par.vals[pn] = old.par.vals[pn] # keep old.par.val as it meets requirement via defaults
} else if (warn){
# otherwise we can drop the old par.val because it does not meet the requirements.
warningf("ParamSetting %s was dropped.", convertToShortString(old.par.vals[pn]))
if (isTRUE(try(requiresOk(par.set$pars[[pn]], c(new.par.vals, old.par.vals[pn])), silent = TRUE))) {
# keep old.par.val in new.par.vals as it meets the requirements
new.par.vals[pn] = old.par.vals[pn]
updated[pn] = FALSE
}
}
# break if no changes were made
if (identical(candidate.par.names, setdiff(names(old.par.vals), names(new.par.vals)))) break
}
return(new.par.vals)
}
setAttribute(new.par.vals, "updated", updated)
}
30 changes: 29 additions & 1 deletion tests/testthat/test_updateParVals.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ test_that("updateParVals works", {
makeLogicalParam("f", default = TRUE))
pc = updateParVals(ps, pa, pb)
expect_equal(pc, list(a = 0, c = 3, d = 4, e = 5))
expect_warning(updateParVals(ps, pa, pb, warn = TRUE), "ParamSetting b=2")
expect_warning(updateParVals(ps, pa, pb, warn = TRUE), "ParamSettings \\(b=2\\)")

pb2 = list(a = 0, f = FALSE)
pc2 = updateParVals(ps, pa, pb2)
expect_equal(pc2, list(a = 0, f = FALSE, d = 4))

pb2 = list(a = 0, f = FALSE)
pc2 = updateParVals(ps, pa, pb2)
Expand Down Expand Up @@ -44,4 +48,28 @@ test_that("updateParVals works", {
pb = list(b = 3)
pc = updateParVals(ps, pa, pb)
expect_equal(pc, list(b = 3, c = TRUE))

#more complicated stuff
ps = makeParamSet(
makeDiscreteLearnerParam(id = "a", default = "a2",
values = c("a1", "a2", "a3"),
requires = quote(!a %in% c("a2") || b == TRUE)),
makeLogicalLearnerParam(id = "b", default = FALSE, tunable = FALSE)
)
pa = list(a = "a1")
pb = list()
pc = updateParVals(ps, pa, pb)
expect_equal(pc, pa)

ps = makeParamSet(
makeIntegerParam("a", default = 10L)
)
pa = list(a = 0L)
pb = list()
pc = updateParVals(ps, pa, pb)
expect_equal(pc, pa)
pb2 = list(a = 5L)
pc2 = updateParVals(ps, pa, pb2)
expect_equal(pc2, pb2)

})