Skip to content

Commit 6077e3e

Browse files
felixcheung (Felix Cheung)
authored and committed
[SPARK-21801][SPARKR][TEST] unit test randomly fail with randomforest
## What changes were proposed in this pull request?

Fix the random seed to eliminate variability.

## How was this patch tested?

Jenkins, AppVeyor, lots more Jenkins.

Author: Felix Cheung <[email protected]>

Closes apache#19018 from felixcheung/rrftest.
1 parent 6327ea5 commit 6077e3e

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

R/pkg/tests/fulltests/test_mllib_tree.R

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ test_that("spark.gbt", {
6666
# label must be binary - GBTClassifier currently only supports binary classification.
6767
iris2 <- iris[iris$Species != "virginica", ]
6868
data <- suppressWarnings(createDataFrame(iris2))
69-
model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
69+
model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification", seed = 12)
7070
stats <- summary(model)
7171
expect_equal(stats$numFeatures, 2)
7272
expect_equal(stats$numTrees, 20)
@@ -94,7 +94,7 @@ test_that("spark.gbt", {
9494

9595
iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
9696
df <- suppressWarnings(createDataFrame(iris2))
97-
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
97+
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification", seed = 12)
9898
s <- summary(m)
9999
# test numeric prediction values
100100
expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
@@ -106,7 +106,7 @@ test_that("spark.gbt", {
106106
if (windows_with_hadoop()) {
107107
data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
108108
source = "libsvm")
109-
model <- spark.gbt(data, label ~ features, "classification")
109+
model <- spark.gbt(data, label ~ features, "classification", seed = 12)
110110
expect_equal(summary(model)$numFeatures, 692)
111111
}
112112

@@ -117,10 +117,11 @@ test_that("spark.gbt", {
117117
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
118118
traindf <- as.DataFrame(data[trainidxs, ])
119119
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
120-
model <- spark.gbt(traindf, clicked ~ ., type = "classification")
120+
model <- spark.gbt(traindf, clicked ~ ., type = "classification", seed = 23)
121121
predictions <- predict(model, testdf)
122122
expect_error(collect(predictions))
123-
model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
123+
model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep",
124+
seed = 23)
124125
predictions <- predict(model, testdf)
125126
expect_equal(class(collect(predictions)$clicked[1]), "character")
126127
})
@@ -129,7 +130,7 @@ test_that("spark.randomForest", {
129130
# regression
130131
data <- suppressWarnings(createDataFrame(longley))
131132
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
132-
numTrees = 1)
133+
numTrees = 1, seed = 1)
133134

134135
predictions <- collect(predict(model, data))
135136
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
@@ -177,7 +178,7 @@ test_that("spark.randomForest", {
177178
# classification
178179
data <- suppressWarnings(createDataFrame(iris))
179180
model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
180-
maxDepth = 5, maxBins = 16)
181+
maxDepth = 5, maxBins = 16, seed = 123)
181182

182183
stats <- summary(model)
183184
expect_equal(stats$numFeatures, 2)
@@ -215,7 +216,7 @@ test_that("spark.randomForest", {
215216
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
216217
data <- suppressWarnings(createDataFrame(iris[-5]))
217218
model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
218-
maxDepth = 5, maxBins = 16)
219+
maxDepth = 5, maxBins = 16, seed = 123)
219220
stats <- summary(model)
220221
expect_equal(stats$numFeatures, 2)
221222
expect_equal(stats$numTrees, 20)
@@ -234,28 +235,29 @@ test_that("spark.randomForest", {
234235
traindf <- as.DataFrame(data[trainidxs, ])
235236
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
236237
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
237-
maxDepth = 10, maxBins = 10, numTrees = 10)
238+
maxDepth = 10, maxBins = 10, numTrees = 10, seed = 123)
238239
predictions <- predict(model, testdf)
239240
expect_error(collect(predictions))
240241
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
241242
maxDepth = 10, maxBins = 10, numTrees = 10,
242-
handleInvalid = "keep")
243+
handleInvalid = "keep", seed = 123)
243244
predictions <- predict(model, testdf)
244245
expect_equal(class(collect(predictions)$clicked[1]), "character")
245246

246247
# spark.randomForest classification can work on libsvm data
247248
if (windows_with_hadoop()) {
248249
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
249250
source = "libsvm")
250-
model <- spark.randomForest(data, label ~ features, "classification")
251+
model <- spark.randomForest(data, label ~ features, "classification", seed = 123)
251252
expect_equal(summary(model)$numFeatures, 4)
252253
}
253254
})
254255

255256
test_that("spark.decisionTree", {
256257
# regression
257258
data <- suppressWarnings(createDataFrame(longley))
258-
model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
259+
model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
260+
seed = 42)
259261

260262
predictions <- collect(predict(model, data))
261263
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
@@ -288,7 +290,7 @@ test_that("spark.decisionTree", {
288290
# classification
289291
data <- suppressWarnings(createDataFrame(iris))
290292
model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
291-
maxDepth = 5, maxBins = 16)
293+
maxDepth = 5, maxBins = 16, seed = 43)
292294

293295
stats <- summary(model)
294296
expect_equal(stats$numFeatures, 2)
@@ -325,7 +327,7 @@ test_that("spark.decisionTree", {
325327
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
326328
data <- suppressWarnings(createDataFrame(iris[-5]))
327329
model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
328-
maxDepth = 5, maxBins = 16)
330+
maxDepth = 5, maxBins = 16, seed = 44)
329331
stats <- summary(model)
330332
expect_equal(stats$numFeatures, 2)
331333
expect_equal(stats$maxDepth, 5)
@@ -339,7 +341,7 @@ test_that("spark.decisionTree", {
339341
if (windows_with_hadoop()) {
340342
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
341343
source = "libsvm")
342-
model <- spark.decisionTree(data, label ~ features, "classification")
344+
model <- spark.decisionTree(data, label ~ features, "classification", seed = 45)
343345
expect_equal(summary(model)$numFeatures, 4)
344346
}
345347

@@ -351,11 +353,11 @@ test_that("spark.decisionTree", {
351353
traindf <- as.DataFrame(data[trainidxs, ])
352354
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
353355
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
354-
maxDepth = 5, maxBins = 16)
356+
maxDepth = 5, maxBins = 16, seed = 46)
355357
predictions <- predict(model, testdf)
356358
expect_error(collect(predictions))
357359
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
358-
maxDepth = 5, maxBins = 16, handleInvalid = "keep")
360+
maxDepth = 5, maxBins = 16, handleInvalid = "keep", seed = 46)
359361
predictions <- predict(model, testdf)
360362
expect_equal(class(collect(predictions)$clicked[1]), "character")
361363
})

0 commit comments

Comments
 (0)