Skip to content

Commit c90f601

Browse files
committed
Catch R errors and reject transition, closes #22
1 parent b09f7c5 commit c90f601

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

R/util.R

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,16 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
100100
call. = FALSE)
101101
}
102102

103-
fn1 <- function(v) { do.call(fn, c(list(v), extra_args_list)) }
103+
fn1 <- function(v) {
104+
# Catch errors in user function and return -Inf with message as attribute
105+
tryCatch({
106+
do.call(fn, c(list(v), extra_args_list))
107+
}, error = function(e) {
108+
res <- -Inf
109+
attr(res, "message") <- e$message
110+
res
111+
})
112+
}
104113
for (chain in seq_len(num_chains)) {
105114
validate_function(fn1, inits[[chain]], extra_args_list, grad = FALSE)
106115
}
@@ -121,7 +130,16 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
121130
fun_packages <- c(gp$packages, packages)
122131
}
123132
if (!is.null(grad_fun)) {
124-
gr1 <- function(v) { do.call(grad_fun, c(list(v), extra_args_list)) }
133+
gr1 <- function(v) {
134+
# Catch errors in user function and return -Inf with message as attribute
135+
tryCatch({
136+
do.call(grad_fun, c(list(v), extra_args_list))
137+
}, error = function(e) {
138+
res <- rep(-Inf, length(v))
139+
attr(res, "message") <- e$message
140+
res
141+
})
142+
}
125143
for (chain in seq_len(num_chains)) {
126144
validate_function(gr1, inits[[chain]], extra_args_list, grad = TRUE)
127145
}

src/include/estimator/estimator_ext_header.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ double r_function(const T& v,
4343
std::ostream* pstream__) {
4444
double lp = 0;
4545
auto v_cons = stan::math::lub_constrain<jacobian__>(v, lower_bounds, upper_bounds, lp);
46-
return Rcpp::as<double>(internal::ll_fun(v_cons)) + lp;
46+
SEXP res = internal::ll_fun(v_cons);
47+
// If the result has a "message" attribute, it indicates an error in the user function
48+
if (Rcpp::RObject(res).hasAttribute("message")) {
49+
std::string msg = Rcpp::as<std::string>(Rcpp::RObject(res).attr("message"));
50+
throw std::domain_error("Error in user-defined function: " + msg);
51+
}
52+
return Rcpp::as<double>(res) + lp;
4753
}
4854

4955
template <bool jacobian__, typename T, stan::require_st_var<T>* = nullptr>
@@ -69,7 +75,13 @@ stan::math::var r_function(const T& v,
6975
rtn = funwrap(v.val());
7076
} else {
7177
arena_v = stan::math::lub_constrain<jacobian__>(v, lower_bounds, upper_bounds, lp);
72-
arena_grad = Rcpp::as<Eigen::VectorXd>(internal::grad_fun(arena_v.val()));
78+
SEXP res = internal::grad_fun(arena_v.val());
79+
// If the result has a "message" attribute, it indicates an error in the user function
80+
if (Rcpp::RObject(res).hasAttribute("message")) {
81+
std::string msg = Rcpp::as<std::string>(Rcpp::RObject(res).attr("message"));
82+
throw std::domain_error("Error in user-defined gradient function: " + msg);
83+
}
84+
arena_grad = Rcpp::as<Eigen::VectorXd>(res);
7385
rtn = Rcpp::as<double>(internal::ll_fun(arena_v.val()));
7486
}
7587
return make_callback_var(

0 commit comments

Comments
 (0)