@@ -66,7 +66,7 @@ test_that("spark.gbt", {
66
66
# label must be binary - GBTClassifier currently only supports binary classification.
67
67
iris2 <- iris [iris $ Species != " virginica" , ]
68
68
data <- suppressWarnings(createDataFrame(iris2 ))
69
- model <- spark.gbt(data , Species ~ Petal_Length + Petal_Width , " classification" )
69
+ model <- spark.gbt(data , Species ~ Petal_Length + Petal_Width , " classification" , seed = 12 )
70
70
stats <- summary(model )
71
71
expect_equal(stats $ numFeatures , 2 )
72
72
expect_equal(stats $ numTrees , 20 )
@@ -94,7 +94,7 @@ test_that("spark.gbt", {
94
94
95
95
iris2 $ NumericSpecies <- ifelse(iris2 $ Species == " setosa" , 0 , 1 )
96
96
df <- suppressWarnings(createDataFrame(iris2 ))
97
- m <- spark.gbt(df , NumericSpecies ~ . , type = " classification" )
97
+ m <- spark.gbt(df , NumericSpecies ~ . , type = " classification" , seed = 12 )
98
98
s <- summary(m )
99
99
# test numeric prediction values
100
100
expect_equal(iris2 $ NumericSpecies , as.double(collect(predict(m , df ))$ prediction ))
@@ -106,7 +106,7 @@ test_that("spark.gbt", {
106
106
if (windows_with_hadoop()) {
107
107
data <- read.df(absoluteSparkPath(" data/mllib/sample_binary_classification_data.txt" ),
108
108
source = " libsvm" )
109
- model <- spark.gbt(data , label ~ features , " classification" )
109
+ model <- spark.gbt(data , label ~ features , " classification" , seed = 12 )
110
110
expect_equal(summary(model )$ numFeatures , 692 )
111
111
}
112
112
@@ -117,10 +117,11 @@ test_that("spark.gbt", {
117
117
trainidxs <- base :: sample(nrow(data ), nrow(data ) * 0.7 )
118
118
traindf <- as.DataFrame(data [trainidxs , ])
119
119
testdf <- as.DataFrame(rbind(data [- trainidxs , ], c(0 , " the other" )))
120
- model <- spark.gbt(traindf , clicked ~ . , type = " classification" )
120
+ model <- spark.gbt(traindf , clicked ~ . , type = " classification" , seed = 23 )
121
121
predictions <- predict(model , testdf )
122
122
expect_error(collect(predictions ))
123
- model <- spark.gbt(traindf , clicked ~ . , type = " classification" , handleInvalid = " keep" )
123
+ model <- spark.gbt(traindf , clicked ~ . , type = " classification" , handleInvalid = " keep" ,
124
+ seed = 23 )
124
125
predictions <- predict(model , testdf )
125
126
expect_equal(class(collect(predictions )$ clicked [1 ]), " character" )
126
127
})
@@ -129,7 +130,7 @@ test_that("spark.randomForest", {
129
130
# regression
130
131
data <- suppressWarnings(createDataFrame(longley ))
131
132
model <- spark.randomForest(data , Employed ~ . , " regression" , maxDepth = 5 , maxBins = 16 ,
132
- numTrees = 1 )
133
+ numTrees = 1 , seed = 1 )
133
134
134
135
predictions <- collect(predict(model , data ))
135
136
expect_equal(predictions $ prediction , c(60.323 , 61.122 , 60.171 , 61.187 ,
@@ -177,7 +178,7 @@ test_that("spark.randomForest", {
177
178
# classification
178
179
data <- suppressWarnings(createDataFrame(iris ))
179
180
model <- spark.randomForest(data , Species ~ Petal_Length + Petal_Width , " classification" ,
180
- maxDepth = 5 , maxBins = 16 )
181
+ maxDepth = 5 , maxBins = 16 , seed = 123 )
181
182
182
183
stats <- summary(model )
183
184
expect_equal(stats $ numFeatures , 2 )
@@ -215,7 +216,7 @@ test_that("spark.randomForest", {
215
216
iris $ NumericSpecies <- lapply(iris $ Species , labelToIndex )
216
217
data <- suppressWarnings(createDataFrame(iris [- 5 ]))
217
218
model <- spark.randomForest(data , NumericSpecies ~ Petal_Length + Petal_Width , " classification" ,
218
- maxDepth = 5 , maxBins = 16 )
219
+ maxDepth = 5 , maxBins = 16 , seed = 123 )
219
220
stats <- summary(model )
220
221
expect_equal(stats $ numFeatures , 2 )
221
222
expect_equal(stats $ numTrees , 20 )
@@ -234,28 +235,29 @@ test_that("spark.randomForest", {
234
235
traindf <- as.DataFrame(data [trainidxs , ])
235
236
testdf <- as.DataFrame(rbind(data [- trainidxs , ], c(0 , " the other" )))
236
237
model <- spark.randomForest(traindf , clicked ~ . , type = " classification" ,
237
- maxDepth = 10 , maxBins = 10 , numTrees = 10 )
238
+ maxDepth = 10 , maxBins = 10 , numTrees = 10 , seed = 123 )
238
239
predictions <- predict(model , testdf )
239
240
expect_error(collect(predictions ))
240
241
model <- spark.randomForest(traindf , clicked ~ . , type = " classification" ,
241
242
maxDepth = 10 , maxBins = 10 , numTrees = 10 ,
242
- handleInvalid = " keep" )
243
+ handleInvalid = " keep" , seed = 123 )
243
244
predictions <- predict(model , testdf )
244
245
expect_equal(class(collect(predictions )$ clicked [1 ]), " character" )
245
246
246
247
# spark.randomForest classification can work on libsvm data
247
248
if (windows_with_hadoop()) {
248
249
data <- read.df(absoluteSparkPath(" data/mllib/sample_multiclass_classification_data.txt" ),
249
250
source = " libsvm" )
250
- model <- spark.randomForest(data , label ~ features , " classification" )
251
+ model <- spark.randomForest(data , label ~ features , " classification" , seed = 123 )
251
252
expect_equal(summary(model )$ numFeatures , 4 )
252
253
}
253
254
})
254
255
255
256
test_that(" spark.decisionTree" , {
256
257
# regression
257
258
data <- suppressWarnings(createDataFrame(longley ))
258
- model <- spark.decisionTree(data , Employed ~ . , " regression" , maxDepth = 5 , maxBins = 16 )
259
+ model <- spark.decisionTree(data , Employed ~ . , " regression" , maxDepth = 5 , maxBins = 16 ,
260
+ seed = 42 )
259
261
260
262
predictions <- collect(predict(model , data ))
261
263
expect_equal(predictions $ prediction , c(60.323 , 61.122 , 60.171 , 61.187 ,
@@ -288,7 +290,7 @@ test_that("spark.decisionTree", {
288
290
# classification
289
291
data <- suppressWarnings(createDataFrame(iris ))
290
292
model <- spark.decisionTree(data , Species ~ Petal_Length + Petal_Width , " classification" ,
291
- maxDepth = 5 , maxBins = 16 )
293
+ maxDepth = 5 , maxBins = 16 , seed = 43 )
292
294
293
295
stats <- summary(model )
294
296
expect_equal(stats $ numFeatures , 2 )
@@ -325,7 +327,7 @@ test_that("spark.decisionTree", {
325
327
iris $ NumericSpecies <- lapply(iris $ Species , labelToIndex )
326
328
data <- suppressWarnings(createDataFrame(iris [- 5 ]))
327
329
model <- spark.decisionTree(data , NumericSpecies ~ Petal_Length + Petal_Width , " classification" ,
328
- maxDepth = 5 , maxBins = 16 )
330
+ maxDepth = 5 , maxBins = 16 , seed = 44 )
329
331
stats <- summary(model )
330
332
expect_equal(stats $ numFeatures , 2 )
331
333
expect_equal(stats $ maxDepth , 5 )
@@ -339,7 +341,7 @@ test_that("spark.decisionTree", {
339
341
if (windows_with_hadoop()) {
340
342
data <- read.df(absoluteSparkPath(" data/mllib/sample_multiclass_classification_data.txt" ),
341
343
source = " libsvm" )
342
- model <- spark.decisionTree(data , label ~ features , " classification" )
344
+ model <- spark.decisionTree(data , label ~ features , " classification" , seed = 45 )
343
345
expect_equal(summary(model )$ numFeatures , 4 )
344
346
}
345
347
@@ -351,11 +353,11 @@ test_that("spark.decisionTree", {
351
353
traindf <- as.DataFrame(data [trainidxs , ])
352
354
testdf <- as.DataFrame(rbind(data [- trainidxs , ], c(0 , " the other" )))
353
355
model <- spark.decisionTree(traindf , clicked ~ . , type = " classification" ,
354
- maxDepth = 5 , maxBins = 16 )
356
+ maxDepth = 5 , maxBins = 16 , seed = 46 )
355
357
predictions <- predict(model , testdf )
356
358
expect_error(collect(predictions ))
357
359
model <- spark.decisionTree(traindf , clicked ~ . , type = " classification" ,
358
- maxDepth = 5 , maxBins = 16 , handleInvalid = " keep" )
360
+ maxDepth = 5 , maxBins = 16 , handleInvalid = " keep" , seed = 46 )
359
361
predictions <- predict(model , testdf )
360
362
expect_equal(class(collect(predictions )$ clicked [1 ]), " character" )
361
363
})
0 commit comments