Skip to content

Commit 0cbfda7

Browse files
authored
feat: start remote workers with mirai (#276)
* ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ...
1 parent a46dd21 commit 0cbfda7

File tree

6 files changed

+79
-50
lines changed

6 files changed

+79
-50
lines changed

DESCRIPTION

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ Suggests:
3838
GenSA,
3939
irace (>= 4.0.0),
4040
knitr,
41+
mirai,
4142
nloptr,
4243
progressr,
4344
processx,
4445
redux,
4546
testthat (>= 3.0.0),
4647
rush (>= 0.1.2)
48+
Remotes:
49+
mlr-org/rush@mirai
4750
Config/testthat/edition: 3
4851
Config/testthat/parallel: false
4952
Encoding: UTF-8

R/Objective.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ Objective = R6Class("Objective",
183183
man = function(rhs) {
184184
assert_ro_binding(rhs)
185185
private$.man
186+
},
187+
188+
#' @field packages (`character()`)\cr
189+
#' Set of required packages.
190+
packages = function(rhs) {
191+
assert_ro_binding(rhs)
192+
private$.packages
186193
}
187194
),
188195

@@ -211,6 +218,7 @@ Objective = R6Class("Objective",
211218
},
212219

213220
.label = NULL,
214-
.man = NULL
221+
.man = NULL,
222+
.packages = NULL
215223
)
216224
)

R/OptimizerAsync.R

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ OptimizerAsync = R6Class("OptimizerAsync",
5454
#' @keywords internal
5555
#' @export
5656
optimize_async_default = function(instance, optimizer, design = NULL, n_workers = NULL) {
57-
assert_class(instance, "OptimInstanceAsync")
58-
assert_class(optimizer, "OptimizerAsync")
5957
assert_data_table(design, null.ok = TRUE)
6058

6159
instance$archive$start_time = Sys.time()
@@ -80,45 +78,50 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
8078
} else {
8179
# run .optimize() on workers
8280
rush = instance$rush
81+
worker_type = rush::rush_config()$worker_type %??% "local"
8382

84-
# FIXME: How to pass globals and packages?
85-
if (rush$n_pre_workers) {
86-
# start remote workers
87-
lg$info("Starting to optimize %i parameter(s) with '%s' and '%s' on %i remote worker(s)",
88-
instance$search_space$length,
89-
optimizer$format(),
90-
instance$terminator$format(with_params = TRUE),
91-
rush$n_pre_workers
92-
)
93-
83+
if (worker_type == "script") {
84+
# worker script
85+
rush$worker_script(
86+
worker_loop = bbotk_worker_loop,
87+
packages = c(optimizer$packages, instance$objective$packages, "bbotk"),
88+
optimizer = optimizer,
89+
instance = instance)
90+
} else if (worker_type == "remote") {
91+
# remote workers
9492
rush$start_remote_workers(
9593
worker_loop = bbotk_worker_loop,
96-
packages = c(optimizer$packages, "bbotk"), # add packages from objective
94+
packages = c(optimizer$packages, instance$objective$packages, "bbotk"),
9795
optimizer = optimizer,
9896
instance = instance)
99-
} else if (rush::rush_available()) {
97+
} else if (worker_type == "local") {
10098
# local workers
101-
lg$info("Starting to optimize %i parameter(s) with '%s' and '%s' on %i remote worker(s)",
102-
instance$search_space$length,
103-
optimizer$format(),
104-
instance$terminator$format(with_params = TRUE),
105-
rush::rush_config()$n_workers
106-
)
107-
10899
rush$start_local_workers(
109100
worker_loop = bbotk_worker_loop,
110-
packages = c(optimizer$packages, "bbotk"), # add packages from objective
101+
packages = c(optimizer$packages, instance$objective$packages, "bbotk"),
111102
optimizer = optimizer,
112103
instance = instance)
113-
} else {
114-
stop("No rush plan available to start local workers and no pre-started remote workers found. See `?rush::rush_plan()`.")
115104
}
116105
}
117106

107+
lg$info("Starting to optimize %i parameter(s) with '%s' and '%s' on %s %s worker(s)",
108+
instance$search_space$length,
109+
optimizer$format(),
110+
instance$terminator$format(with_params = TRUE),
111+
as.character(rush::rush_config()$n_workers %??% ""),
112+
worker_type)
113+
114+
n_running_workers = 0
118115
# wait until optimization is finished
119116
# check terminated workers when the terminator is "none"
120117
while(TRUE) {
121118
Sys.sleep(1)
119+
120+
if (rush$n_running_workers > n_running_workers) {
121+
n_running_workers = rush$n_running_workers
122+
lg$info("%i worker(s) started", n_running_workers)
123+
}
124+
122125
instance$rush$print_log()
123126

124127
# fetch new results for printing
@@ -133,7 +136,10 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
133136
}
134137

135138
if (instance$is_terminated) break
136-
if (instance$rush$all_workers_terminated) break
139+
if (instance$rush$all_workers_terminated) {
140+
lg$info("All workers have terminated.")
141+
break
142+
}
137143
}
138144

139145
# assign result

man/Objective.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_OptimInstanceAsyncSingleCrit.R

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,6 @@ test_that("reconnect method works", {
7979
skip_if_not_installed("rush")
8080
flush_redis()
8181

82-
on.exit({
83-
file.remove("instance.rds")
84-
})
85-
8682
rush::rush_plan(n_workers = 2)
8783

8884
instance = oi_async(
@@ -94,8 +90,9 @@ test_that("reconnect method works", {
9490
optimizer = opt("async_random_search")
9591
optimizer$optimize(instance)
9692

97-
saveRDS(instance, file = "instance.rds")
98-
instance = readRDS("instance.rds")
93+
file = tempfile(fileext = ".rds")
94+
saveRDS(instance, file = file )
95+
instance = readRDS(file)
9996

10097
instance$reconnect()
10198

tests/testthat/test_OptimizerAsync.R

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,50 +9,62 @@ test_that("OptimizerAsync starts local workers", {
99
instance = oi_async(
1010
objective = OBJ_2D,
1111
search_space = PS_2D,
12-
terminator = trm("evals", n_evals = 5L),
12+
terminator = trm("evals", n_evals = 50L),
1313
)
1414

1515
optimizer = opt("async_random_search")
1616
optimizer$optimize(instance)
1717

1818
expect_data_table(instance$rush$worker_info, nrows = 2)
19+
expect_list(instance$rush$processes_processx, len = 2)
1920

2021
expect_rush_reset(instance$rush)
2122
})
2223

2324
test_that("OptimizerAsync starts remote workers", {
2425
skip_on_cran()
25-
skip_if_not_installed("rush")
26-
skip_if_not_installed("processx")
26+
skip_if_not_installed(c("rush", "mirai"))
2727
flush_redis()
28-
library(processx)
28+
library(rush)
29+
30+
mirai::daemons(2)
31+
32+
rush_plan(n_workers = 2, worker_type = "remote")
33+
34+
instance = oi_async(
35+
objective = OBJ_2D,
36+
search_space = PS_2D,
37+
terminator = trm("evals", n_evals = 5L),
38+
)
39+
40+
optimizer = opt("async_random_search")
41+
optimizer$optimize(instance)
2942

30-
rush = rsh(network_id = "test_rush")
31-
expect_snapshot(rush$create_worker_script())
43+
expect_data_table(instance$rush$worker_info, nrows = 2)
44+
expect_list(instance$rush$processes_mirai, len = 2)
3245

33-
px = process$new("Rscript",
34-
args = c("-e", 'rush::start_worker(network_id = "test_rush", remote = TRUE, url = "redis://127.0.0.1:6379", scheme = "redis", host = "127.0.0.1", port = "6379")'),
35-
supervise = TRUE,
36-
stderr = "|", stdout = "|")
46+
expect_rush_reset(instance$rush)
47+
mirai::daemons(0)
48+
})
3749

38-
on.exit({
39-
px$kill()
40-
}, add = TRUE)
50+
test_that("OptimizerAsync defaults to local worker", {
51+
skip_on_cran()
52+
skip_if_not_installed("rush")
53+
flush_redis()
54+
library(rush)
4155

42-
Sys.sleep(5)
56+
rush_plan(n_workers = 2)
4357

4458
instance = oi_async(
4559
objective = OBJ_2D,
4660
search_space = PS_2D,
47-
terminator = trm("evals", n_evals = 5L),
48-
rush = rush
61+
terminator = trm("evals", n_evals = 50L),
4962
)
5063

5164
optimizer = opt("async_random_search")
5265
optimizer$optimize(instance)
5366

54-
expect_data_table(instance$rush$worker_info, nrows = 1)
55-
expect_true(instance$rush$worker_info$remote)
67+
expect_data_table(instance$rush$worker_info, nrows = 2)
5668

5769
expect_rush_reset(instance$rush)
5870
})

0 commit comments

Comments
 (0)