Skip to content

Commit 2bb67dc

Browse files
committed
update sdm methods: add gbm2 for Boosted Regression Trees and refine svm2 for better classification
1 parent eaf60a5 commit 2bb67dc

File tree

4 files changed

+111
-57
lines changed

4 files changed

+111
-57
lines changed

R/mod_ssdm_fit.R

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
#'
1515
#' @param sdm_method Character. A single SDM algorithm to use for fitting
1616
#' models. Valid values: "glm", "glmpoly", "gam", "glmnet/glmnet2",
17-
#' "mars/mars2", "gbm", "rf/rf2", "ranger/ranger2", "cart", "rpart", "maxent",
18-
#' "mlp", "rbf", "svm/svm2", "mda/mda2", and "fda/fda2". These correspond to
19-
#' selected methods supported by the `sdm` package. For details and supported
20-
#' options, see [sdm::getmethodNames()]. Note that some methods have
21-
#' custom implementations (e.g., "glmnet2", "mars2", "ranger2", "rf2",
22-
#' "svm2", "mda2", "fda2") to ensure consistent parameterization and
17+
#' "mars/mars2", "gbm/gbm2", "rf/rf2", "ranger/ranger2", "cart", "rpart",
18+
#' "maxent", "mlp", "rbf", "svm/svm2", "mda/mda2", and "fda/fda2". These
19+
#' correspond to selected methods supported by the `sdm` package. For details
20+
#' and supported options, see [sdm::getmethodNames()]. Note that some methods
21+
#' have custom implementations (e.g., "glmnet2", "gbm2", "mars2", "ranger2",
22+
#' "rf2", "svm2", "mda2", "fda2") to ensure consistent parameterisation and
2323
#' performance across models.
2424
#' @param model_settings List or NULL. List of model-specific settings. If
2525
#' `NULL`, defaults to custom settings defined within the workflow.
@@ -230,9 +230,9 @@ fit_sdm_models <- function(
230230

231231
# rbf is not bounded; see https://github.com/babaknaimi/sdm/issues/42
232232
valid_sdm_methods <- c(
233-
"glm", "glmpoly", "gam", "glmnet", "glmnet2", "mars", "mars2", "gbm", "rf",
234-
"rf2", "ranger", "ranger2", "cart", "rpart", "maxent", "mlp", "svm",
235-
"svm2", "mda", "mda2", "fda", "fda2")
233+
"glm", "glmpoly", "gam", "glmnet", "glmnet2", "mars", "mars2", "gbm",
234+
"gbm2", "rf", "rf2", "ranger", "ranger2", "cart", "rpart", "maxent", "mlp",
235+
"svm", "svm2", "mda", "mda2", "fda", "fda2")
236236
sdm_method_valid <- any(
237237
is.null(sdm_method), length(sdm_method) != 1L,
238238
!is.character(sdm_method), !sdm_method %in% valid_sdm_methods)
@@ -344,6 +344,8 @@ fit_sdm_models <- function(
344344
"mars", "earth",
345345
"mars2", "earth",
346346
"gbm", "gbm",
347+
"gbm2", "gbm",
348+
"gbm2", "dismo",
347349
"rf", "randomForest",
348350
"rf2", "randomForest",
349351
"ranger", "ranger",

R/mod_ssdm_helpers.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,8 +1115,8 @@ prepare_input_data <- function(
11151115
paste(collapse = " + ") %>%
11161116
ecokit::cat_time(cat_timestamp = FALSE, level = 1, ... = "\n")
11171117

1118-
species_modelling_data <- species_modelling_data %>%
1119-
dplyr::filter(!species_name %in% excluded_species)
1118+
species_modelling_data <- dplyr::filter(
1119+
species_modelling_data, !species_name %in% excluded_species)
11201120

11211121
} else {
11221122
excluded_species <- NA_character_
@@ -1596,7 +1596,8 @@ sdm_model_settings <- function() {
15961596
glmnet2 = list(maxit = 100000L),
15971597
mars = list(),
15981598
mars2 = list(),
1599-
gbm = list(n.trees = 2000L, interaction.depth = 2L),
1599+
gbm = list(n.trees = 2000L, interaction.depth = 4L),
1600+
gbm2 = list(n.trees = 2000L, interaction.depth = 4L)),
16001601
rf = list(ntree = 1000L, nodesize = 5L),
16011602
rf2 = list(ntree = 1000L, nodesize = 5L),
16021603
ranger = list(

inst/brt2.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Author: Ahmed El-Gabbas Date: 2025-12-17 Licence GPL v3
2+
#
3+
# - gbm2 fits Boosted Regression Trees (BRTs) using `dismo::gbm.step`, which
4+
# selects an optimal number of trees during training based on a step-wise,
5+
# cross-validated procedure.
6+
# - Predictions are produced with `gbm::predict.gbm` at the best iteration
7+
# chosen by gbm.step (object$gbm.call$best.trees).
8+
#
9+
# The original gbm method in the sdm package shows low maximum predicted values
10+
11+
methodInfo <- list(
12+
name = c("gbm2", "GBM2", "brt2", "BRT2"),
13+
packages = c("dismo", "gbm"),
14+
modelTypes = c("pa", "pb", "ab", "n"),
15+
fitParams = list(
16+
formula = "standard.formula", data = "sdmDataFrame",
17+
v = "sdmVariables"),
18+
fitSettings = list(
19+
tree.complexity = 2, learning.rate = 0.01, bag.fraction = 0.5,
20+
n.folds = 10, max.trees = 10000, step.size = 50, keep.data = TRUE),
21+
22+
fitFunction = function(formula, data, v, ...) {
23+
fam <- switch(
24+
v@distribution,
25+
poisson = "poisson",
26+
multinomial = "multinomial",
27+
n = "gaussian",
28+
"bernoulli")
29+
30+
dismo::gbm.step(
31+
data = data,
32+
gbm.x = all.vars(formula)[-1],
33+
gbm.y = all.vars(formula)[1],
34+
family = fam,
35+
plot.main = FALSE,
36+
verbose = FALSE,
37+
silent = TRUE,
38+
...
39+
)
40+
},
41+
settingRules = NULL,
42+
tuneParams = NULL,
43+
predictParams = list(
44+
object = "model", formula = "standard.formula",
45+
newx = "sdmDataFrame", v = "sdmVariables"),
46+
predictSettings = list(type = "response"),
47+
predictFunction = function(object, formula, newx, v, type, ...) {
48+
gbm::predict.gbm(
49+
object = object,
50+
newdata = newx,
51+
n.trees = object$gbm.call$best.trees,
52+
type = type)
53+
}
54+
)

inst/svm2.R

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,67 @@
11
# Author: Ahmed El-Gabbas
2-
# Date: 2025-12-15
3-
# Based on: `babaknaimi/sdm/inst/methods/sdm/svm.R`
2+
# Date: 2025-12-18
43
#
5-
# - New method using e1071::svm backend with minimal changes.
6-
# - Enable probability outputs for classification by setting
7-
# `probability = TRUE` when not provided by the user. This matches original
8-
# svm's probability behavior.
9-
# - Predict function returns positive-class probability when available.
10-
# - All other arguments are passed via ... without overriding defaults.
4+
# - Binary classification for PA/PB using e1071::svm with probability outputs.
5+
# - Tuning: Parsimonious grid with 5-fold CV via e1071::tune:
6+
# kernel = "radial"
7+
# cost ∈ {1, 5, 10}
8+
# gamma ∈ {0.01, 0.05, 0.1}
9+
# Best model (tune.out$best.model) is returned for prediction.
10+
# - Class weights: Inverse-prevalence weighting (capped at 20) to handle
11+
# imbalance:
12+
# n0 = count of class "0", n1 = count of class "1"
13+
# weights = c("0" = 1, "1" = min(n0 / max(1, n1), 20)) if n0 >= n1
14+
# c("0" = min(n1 / max(1, n0), 20), "1" = 1) otherwise
15+
# - Prediction: predict(..., probability = TRUE); returns positive-class
16+
# probability:
17+
# * Assumes response encoded as 0/1 and present as the left-hand side of the formula.
18+
# * Predictors and their names in new data must match those used at training.
19+
# * Grid is intentionally small so that tuning remains fast and models robust.
1120

1221
methodInfo <- list(
13-
name = c("svm2", "SVM2", "svm_e1071"),
22+
name = c("svm3", "SVM3", "svm_e1071_3"),
1423
packages = "e1071",
1524
modelTypes = c("pa", "pb", "ab", "n"),
1625
fitParams = list(
1726
formula = "standard.formula", data = "sdmDataFrame", v = "sdmVariables"),
18-
fitSettings = list(kernel = "radial", probability = TRUE),
19-
27+
fitSettings = list(kernel = "radial"),
2028
fitFunction = function(formula, data, v, ...) {
21-
x <- sdm:::.getData.sdmMatrix(
22-
formula, data, normalize = TRUE, frame = v@varInfo$numeric, scale = FALSE)
23-
y <- sdm:::.getData.sdmY(formula, data)
24-
25-
# class.weights is used to counter severe class imbalance in PA/PB data.
26-
#
27-
# In presence–absence SDMs, absences (0) often vastly outnumber presences
28-
# (1). Without weighting, the SVM’s loss is dominated by the majority class,
29-
# leading to poor discrimination (e.g., predicting almost all 0s).
3029

31-
# Compute counts of absences (n0) and presences (n1)
32-
n0 <- sum(y == 0, na.rm = TRUE)
33-
n1 <- sum(y == 1, na.rm = TRUE)
30+
formula <- as.formula(deparse(formula), env = environment())
31+
resp <- all.vars(formula)[1]
32+
data[, resp] <- factor(data[, resp], levels = c(0L, 1L))
3433

35-
# Upper bound for weight
34+
# Upweight the minority class
35+
n0 <- sum(data[, resp] == "0")
36+
n1 <- sum(data[, resp] == "1")
3637
max_weight <- 20
37-
38-
# - Upweight the minority class so its misclassification cost is comparable
39-
# to the majority class, using a simple inverse-prevalence rule:
40-
# minority_weight = n_majority / n_minority
41-
# - Cap the weight by max_weight (here 20) to avoid numeric instability and
42-
# overly aggressive rebalancing on extremely imbalanced folds.
43-
# - If absences are the majority (n0 >= n1), set weight for class "1"
44-
# (presence) to min(n0 / n1, max_weight), keep class "0" at 1. Otherwise,
45-
# upweight class "0" similarly when presences are the majority.
46-
# - Pass class.weights to e1071::svm so the optimization accounts for
47-
# imbalance while leaving all other defaults unchanged.
4838
if (n0 >= n1) {
49-
# More absences
50-
class.weights <- c("0" = 1, "1" = min(n0 / n1, max_weight))
39+
class.weights <- setNames(
40+
c(1, min(n0 / max(1, n1), max_weight)),
41+
c("0", "1"))
5142
} else {
52-
# More presences
53-
class.weights <- c("0" = min(n1 / n0, max_weight), "1" = 1)
43+
class.weights <- setNames(
44+
c(min(n1 / max(1, n0), max_weight), 1),
45+
c("0", "1"))
5446
}
5547

56-
e1071::svm(x = x, y = y, scale = TRUE, class.weights = class.weights, ...)
48+
tune.out <- e1071::tune(
49+
e1071::svm, train.x = formula, data = data, kernel = "radial",
50+
ranges = list(cost = c(1, 5, 10), gamma = c(0.01, 0.05, 0.1)),
51+
class.weights = class.weights, probability = TRUE,
52+
tunecontrol = e1071::tune.control(cross = 5))
53+
54+
tune.out$best.model
5755
},
5856
settingRules = NULL,
5957
tuneParams = NULL,
6058
predictParams = list(
6159
object = "model", formula = "standard.formula", newx = "sdmDataFrame",
6260
v = "sdmVariables"),
63-
predictSettings = list(probability = TRUE),
61+
predictSettings = list(),
6462
predictFunction = function(object, formula, newx, v, ...) {
65-
newx <- sdm:::.getData.sdmMatrix(
66-
formula, newx, normalize = TRUE,
67-
frame = v@varInfo$numeric, scale = FALSE)
68-
predict(object, newx, ...)
63+
pred_probs <- predict(
64+
object = object, newdata = newx, probability = TRUE, ...)
65+
attr(pred_probs, "probabilities")[, "1"]
6966
}
7067
)

0 commit comments

Comments
 (0)