Skip to content

Commit 7f74ab0

Browse files
committed
R interface for raw data and sufficient statistics
1 parent 447c64b commit 7f74ab0

File tree

5 files changed

+74
-8
lines changed

5 files changed

+74
-8
lines changed

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact
3333
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
3434
}
3535

36-
sample_ggm <- function(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) {
37-
.Call(`_bgms_sample_ggm`, X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)
36+
sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) {
37+
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)
3838
}
3939

4040
compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) {

src/RcppExports.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,12 @@ BEGIN_RCPP
182182
END_RCPP
183183
}
184184
// sample_ggm
185-
Rcpp::List sample_ggm(const arma::mat& X, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type);
186-
RcppExport SEXP _bgms_sample_ggm(SEXP XSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) {
185+
Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type);
186+
RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) {
187187
BEGIN_RCPP
188188
Rcpp::RObject rcpp_result_gen;
189189
Rcpp::RNGScope rcpp_rngScope_gen;
190-
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
190+
Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP);
191191
Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP);
192192
Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP);
193193
Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP);
@@ -197,7 +197,7 @@ BEGIN_RCPP
197197
Rcpp::traits::input_parameter< const int >::type seed(seedSEXP);
198198
Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP);
199199
Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP);
200-
rcpp_result_gen = Rcpp::wrap(sample_ggm(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type));
200+
rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type));
201201
return rcpp_result_gen;
202202
END_RCPP
203203
}

src/ggm_model.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,35 @@ void GGMModel::do_one_mh_step() {
610610
// could also be called in the main MCMC loop
611611
proposal_.increment_iteration();
612612
}
613+
614+
615+
GGMModel createGGMFromR(
616+
const Rcpp::List& inputFromR,
617+
const arma::mat& prior_inclusion_prob,
618+
const arma::imat& initial_edge_indicators,
619+
const bool edge_selection
620+
) {
621+
622+
if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) {
623+
int n = Rcpp::as<int>(inputFromR["n"]);
624+
arma::mat suf_stat = Rcpp::as<arma::mat>(inputFromR["suf_stat"]);
625+
return GGMModel(
626+
n,
627+
suf_stat,
628+
prior_inclusion_prob,
629+
initial_edge_indicators,
630+
edge_selection
631+
);
632+
} else if (inputFromR.containsElementNamed("X")) {
633+
arma::mat X = Rcpp::as<arma::mat>(inputFromR["X"]);
634+
return GGMModel(
635+
X,
636+
prior_inclusion_prob,
637+
initial_edge_indicators,
638+
edge_selection
639+
);
640+
} else {
641+
throw std::invalid_argument("Input list must contain either 'X' or both 'n' and 'suf_stat'.");
642+
}
643+
644+
}

src/ggm_model.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "adaptiveMetropolis.h"
66
#include "rng_utils.h"
77

8+
89
class GGMModel : public BaseModel {
910
public:
1011

@@ -31,6 +32,30 @@ class GGMModel : public BaseModel {
3132
constants_(6)
3233
{}
3334

35+
GGMModel(
36+
const int n,
37+
const arma::mat& suf_stat,
38+
const arma::mat& prior_inclusion_prob,
39+
const arma::imat& initial_edge_indicators,
40+
const bool edge_selection = true
41+
) : n_(n),
42+
p_(suf_stat.n_cols),
43+
dim_((p_ * (p_ + 1)) / 2),
44+
suf_stat_(suf_stat),
45+
prior_inclusion_prob_(prior_inclusion_prob),
46+
edge_selection_(edge_selection),
47+
proposal_(AdaptiveProposal(dim_, 500)),
48+
omega_(arma::eye<arma::mat>(p_, p_)),
49+
phi_(arma::eye<arma::mat>(p_, p_)),
50+
inv_phi_(arma::eye<arma::mat>(p_, p_)),
51+
inv_omega_(arma::eye<arma::mat>(p_, p_)),
52+
edge_indicators_(initial_edge_indicators),
53+
vectorized_parameters_(dim_),
54+
vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0),
55+
omega_prop_(arma::mat(p_, p_, arma::fill::none)),
56+
constants_(6)
57+
{}
58+
3459
GGMModel(const GGMModel& other)
3560
: BaseModel(other),
3661
dim_(other.dim_),
@@ -161,3 +186,11 @@ class GGMModel : public BaseModel {
161186
// double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal);
162187
// double diag_log_ratio(const arma::mat& omega, size_t i, double proposal);
163188
};
189+
190+
191+
GGMModel createGGMFromR(
192+
const Rcpp::List& inputFromR,
193+
const arma::mat& prior_inclusion_prob,
194+
const arma::imat& initial_edge_indicators,
195+
const bool edge_selection = true
196+
);

src/sample_ggm.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ Rcpp::List convert_sampler_output_to_ggm_result(const std::vector<ChainResultNew
159159

160160
// [[Rcpp::export]]
161161
Rcpp::List sample_ggm(
162-
const arma::mat& X,
162+
const Rcpp::List& inputFromR,
163163
const arma::mat& prior_inclusion_prob,
164164
const arma::imat& initial_edge_indicators,
165165
const int no_iter,
@@ -173,7 +173,8 @@ Rcpp::List sample_ggm(
173173

174174
// should be done dynamically
175175
// also adaptation method should be specified differently
176-
GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection);
176+
// GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection);
177+
GGMModel model = createGGMFromR(inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection);
177178

178179
ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type);
179180

0 commit comments

Comments
 (0)