Skip to content

Commit 618a6f7

Browse files
Closes #18
1 parent 155d7a8 commit 618a6f7

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

R/bayesOpt.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,20 +236,21 @@ bayesOpt <- function(
236236

237237
# Initialization Setup
238238
if (missing(initGrid) + missing(initPoints) != 1) stop("Please provide 1 of initGrid or initPoints, but not both.")
239-
if (initPoints <= length(bounds)) stop("initPoints must be greater than the number of FUN inputs.")
240239
if (!missing(initGrid)) {
241240
setDT(initGrid)
242241
inBounds <- checkBounds(initGrid,bounds)
243242
inBounds <- as.logical(apply(inBounds,1,prod))
244243
if (any(!inBounds)) stop("initGrid not within bounds.")
245244
optObj$initPars$initialSample <- "User Provided Grid"
245+
initPoints <- nrow(initGrid)
246246
} else {
247247
initGrid <- randParams(boundsDT, initPoints)
248248
optObj$initPars$initialSample <- "Latin Hypercube Sampling"
249249
}
250250
optObj$initPars$initGrid <- initGrid
251251
if (nrow(initGrid) <= 2) stop("Cannot initialize with less than 3 samples.")
252252
optObj$initPars$initPoints <- nrow(initGrid)
253+
if (initPoints <= length(bounds)) stop("initPoints must be greater than the number of FUN inputs.")
253254

254255
# Output from FUN is sunk into a temporary file.
255256
sinkFile <- file()

tests/testthat/test-hyperparameterTuning.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ testthat::test_that(
7575
, alpha = c(0,1)
7676
)
7777

78+
initGrid <- data.table(
79+
max_depth = c(1,1,2,2,3,3,4,4,5)
80+
, max_leaves = c(2,3,4,5,6,7,8,9,10)
81+
, min_child_weight = seq(bounds$min_child_weight[1],bounds$min_child_weight[2],length.out = 9)
82+
, subsample = seq(bounds$subsample[1],bounds$subsample[2],length.out = 9)
83+
, colsample_bytree = seq(bounds$colsample_bytree[1],bounds$colsample_bytree[2],length.out = 9)
84+
, gamma = seq(bounds$gamma[1],bounds$gamma[2],length.out = 9)
85+
, lambda = seq(bounds$lambda[1],bounds$lambda[2],length.out = 9)
86+
, alpha = seq(bounds$alpha[1],bounds$alpha[2],length.out = 9)
87+
)
88+
7889
optObj <- bayesOpt(
7990
FUN = scoringFunction
8091
, bounds = bounds
@@ -86,6 +97,17 @@ testthat::test_that(
8697

8798
expect_equal(nrow(optObj$scoreSummary),13)
8899

100+
optObj <- bayesOpt(
101+
FUN = scoringFunction
102+
, bounds = bounds
103+
, initGrid = initGrid
104+
, iters.n = 4
105+
, iters.k = 1
106+
, gsPoints = 10
107+
)
108+
109+
expect_equal(nrow(optObj$scoreSummary),13)
110+
89111
}
90112

91113
)

0 commit comments

Comments
 (0)