Skip to content

Commit 447c64b

Browse files
committed
reduce allocations and factor out common cholesky stuff
1 parent 81c6511 commit 447c64b

File tree

2 files changed

+151
-77
lines changed

2 files changed

+151
-77
lines changed

src/ggm_model.cpp

Lines changed: 142 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ void GGMModel::get_constants(size_t i, size_t j) {
1212
// TODO: helper function?
1313
double logdet_omega = get_log_det(phi_);
1414

15-
double log_adj_omega_ii = logdet_omega + log(abs(inv_omega_(i, i)));
16-
double log_adj_omega_ij = logdet_omega + log(abs(inv_omega_(i, j)));
17-
double log_adj_omega_jj = logdet_omega + log(abs(inv_omega_(j, j)));
15+
double log_adj_omega_ii = logdet_omega + std::log(std::abs(inv_omega_(i, i)));
16+
double log_adj_omega_ij = logdet_omega + std::log(std::abs(inv_omega_(i, j)));
17+
double log_adj_omega_jj = logdet_omega + std::log(std::abs(inv_omega_(j, j)));
1818

1919
double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j);
20-
double log_abs_inv_omega_sub_jj = log_adj_omega_ii + log(abs(inv_omega_sub_j1j1));
21-
20+
double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1));
2221
double Phi_q1q = (2 * std::signbit(inv_omega_(i, j)) - 1) * std::exp(
2322
(log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2)
2423
);
@@ -191,49 +190,89 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) {
191190
// accept proposal
192191
proposal_.increment_accepts(e);
193192

194-
double omega_ij = omega_(i, j);
195-
double omega_jj = omega_(j, j);
193+
double omega_ij_old = omega_(i, j);
194+
double omega_jj_old = omega_(j, j);
195+
196196

197197
omega_(i, j) = omega_prop_q1q;
198198
omega_(j, i) = omega_prop_q1q;
199199
omega_(j, j) = omega_prop_qq;
200200

201-
// TODO: preallocate?
202-
// find v for low rank update
203-
arma::vec v1 = {0, -1};
204-
arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2};
201+
cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j);
202+
203+
// // TODO: preallocate?
204+
// // find v for low rank update
205+
// arma::vec v1 = {0, -1};
206+
// arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2};
205207

206-
arma::vec vf1 = arma::zeros<arma::vec>(p_);
207-
arma::vec vf2 = arma::zeros<arma::vec>(p_);
208-
vf1[i] = v1[0];
209-
vf1[j] = v1[1];
210-
vf2[i] = v2[0];
211-
vf2[j] = v2[1];
208+
// arma::vec vf1 = arma::zeros<arma::vec>(p_);
209+
// arma::vec vf2 = arma::zeros<arma::vec>(p_);
210+
// vf1[i] = v1[0];
211+
// vf1[j] = v1[1];
212+
// vf2[i] = v2[0];
213+
// vf2[j] = v2[1];
212214

213-
// we now have
214-
// aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1))
215+
// // we now have
216+
// // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1))
215217

216-
arma::vec u1 = (vf1 + vf2) / sqrt(2);
217-
arma::vec u2 = (vf1 - vf2) / sqrt(2);
218+
// arma::vec u1 = (vf1 + vf2) / sqrt(2);
219+
// arma::vec u2 = (vf1 - vf2) / sqrt(2);
218220

219-
// we now have
220-
// omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2))
221-
// and also
222-
// aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2)))
221+
// // we now have
222+
// // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2))
223+
// // and also
224+
// // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2)))
223225

224-
// update phi (2x O(p^2))
225-
cholesky_update(phi_, u1);
226-
cholesky_downdate(phi_, u2);
226+
// // update phi (2x O(p^2))
227+
// cholesky_update(phi_, u1);
228+
// cholesky_downdate(phi_, u2);
227229

228-
// update inverse (2x O(p^2))
229-
arma::inv(inv_phi_, arma::trimatu(phi_));
230-
inv_omega_ = inv_phi_ * inv_phi_.t();
230+
// // update inverse (2x O(p^2))
231+
// arma::inv(inv_phi_, arma::trimatu(phi_));
232+
// inv_omega_ = inv_phi_ * inv_phi_.t();
231233

232234
}
233235

234236
proposal_.update_proposal_sd(e);
235237
}
236238

239+
void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j)
240+
{
241+
242+
v2_[0] = omega_ij_old - omega_prop_(i, j);
243+
v2_[1] = (omega_jj_old - omega_prop_(j, j)) / 2;
244+
245+
vf1_[i] = v1_[0];
246+
vf1_[j] = v1_[1];
247+
vf2_[i] = v2_[0];
248+
vf2_[j] = v2_[1];
249+
250+
// we now have
251+
// aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1))
252+
253+
u1_ = (vf1_ + vf2_) / sqrt(2);
254+
u2_ = (vf1_ - vf2_) / sqrt(2);
255+
256+
// we now have
257+
// omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2))
258+
// and also
259+
// aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2)))
260+
261+
// update phi (2x O(p^2))
262+
cholesky_update(phi_, u1_);
263+
cholesky_downdate(phi_, u2_);
264+
265+
// update inverse (2x O(p^2))
266+
arma::inv(inv_phi_, arma::trimatu(phi_));
267+
inv_omega_ = inv_phi_ * inv_phi_.t();
268+
269+
// reset for next iteration
270+
vf1_[i] = 0.0;
271+
vf1_[j] = 0.0;
272+
vf2_[i] = 0.0;
273+
vf2_[j] = 0.0;
274+
275+
}
237276

238277
void GGMModel::update_diagonal_parameter(size_t i) {
239278
// Implementation of diagonal parameter update
@@ -280,29 +319,53 @@ void GGMModel::update_diagonal_parameter(size_t i) {
280319
proposal_.increment_accepts(e);
281320

282321
double omega_ii = omega_(i, i);
322+
omega_(i, i) = omega_prop_(i, i);
283323

284-
arma::vec u(p_, arma::fill::zeros);
285-
double delta = omega_ii - omega_prop_(i, i);
286-
bool s = delta > 0;
287-
u(i) = std::sqrt(std::abs(delta));
324+
cholesky_update_after_diag(omega_ii, i);
288325

289-
omega_(i, i) = omega_prop_(i, i);
326+
// arma::vec u(p_, arma::fill::zeros);
327+
// double delta = omega_ii - omega_prop_(i, i);
328+
// bool s = delta > 0;
329+
// u(i) = std::sqrt(std::abs(delta));
290330

291-
if (s)
292-
cholesky_downdate(phi_, u);
293-
else
294-
cholesky_update(phi_, u);
295331

296-
// update inverse (2x O(p^2))
297-
arma::inv(inv_phi_, arma::trimatu(phi_));
298-
inv_omega_ = inv_phi_ * inv_phi_.t();
332+
// if (s)
333+
// cholesky_downdate(phi_, u);
334+
// else
335+
// cholesky_update(phi_, u);
336+
337+
// // update inverse (2x O(p^2))
338+
// arma::inv(inv_phi_, arma::trimatu(phi_));
339+
// inv_omega_ = inv_phi_ * inv_phi_.t();
299340

300341

301342
}
302343

303344
proposal_.update_proposal_sd(e);
304345
}
305346

347+
void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i)
348+
{
349+
350+
double delta = omega_ii_old - omega_prop_(i, i);
351+
352+
bool s = delta > 0;
353+
vf1_(i) = std::sqrt(std::abs(delta));
354+
355+
if (s)
356+
cholesky_downdate(phi_, vf1_);
357+
else
358+
cholesky_update(phi_, vf1_);
359+
360+
// update inverse (2x O(p^2))
361+
arma::inv(inv_phi_, arma::trimatu(phi_));
362+
inv_omega_ = inv_phi_ * inv_phi_.t();
363+
364+
// reset for next iteration
365+
vf1_(i) = 0.0;
366+
}
367+
368+
306369
void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) {
307370

308371
size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form
@@ -351,27 +414,28 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) {
351414
edge_indicators_(i, j) = 0;
352415
edge_indicators_(j, i) = 0;
353416

354-
// Cholesky update vectors
355-
arma::vec v1 = {0, -1};
356-
arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2};
417+
cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j);
418+
// // Cholesky update vectors
419+
// arma::vec v1 = {0, -1};
420+
// arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2};
357421

358-
arma::vec vf1 = arma::zeros<arma::vec>(p_);
359-
arma::vec vf2 = arma::zeros<arma::vec>(p_);
360-
vf1[i] = v1[0];
361-
vf1[j] = v1[1];
362-
vf2[i] = v2[0];
363-
vf2[j] = v2[1];
422+
// arma::vec vf1 = arma::zeros<arma::vec>(p_);
423+
// arma::vec vf2 = arma::zeros<arma::vec>(p_);
424+
// vf1[i] = v1[0];
425+
// vf1[j] = v1[1];
426+
// vf2[i] = v2[0];
427+
// vf2[j] = v2[1];
364428

365-
arma::vec u1 = (vf1 + vf2) / sqrt(2);
366-
arma::vec u2 = (vf1 - vf2) / sqrt(2);
429+
// arma::vec u1 = (vf1 + vf2) / sqrt(2);
430+
// arma::vec u2 = (vf1 - vf2) / sqrt(2);
367431

368-
// Update Cholesky factor
369-
cholesky_update(phi_, u1);
370-
cholesky_downdate(phi_, u2);
432+
// // Update Cholesky factor
433+
// cholesky_update(phi_, u1);
434+
// cholesky_downdate(phi_, u2);
371435

372-
// Update inverse
373-
arma::inv(inv_phi_, arma::trimatu(phi_));
374-
inv_omega_ = inv_phi_ * inv_phi_.t();
436+
// // Update inverse
437+
// arma::inv(inv_phi_, arma::trimatu(phi_));
438+
// inv_omega_ = inv_phi_ * inv_phi_.t();
375439
}
376440

377441
} else {
@@ -425,27 +489,28 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) {
425489
edge_indicators_(i, j) = 1;
426490
edge_indicators_(j, i) = 1;
427491

428-
// Cholesky update vectors
429-
arma::vec v1 = {0, -1};
430-
arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2};
492+
cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j);
493+
// // Cholesky update vectors
494+
// arma::vec v1 = {0, -1};
495+
// arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2};
431496

432-
arma::vec vf1 = arma::zeros<arma::vec>(p_);
433-
arma::vec vf2 = arma::zeros<arma::vec>(p_);
434-
vf1[i] = v1[0];
435-
vf1[j] = v1[1];
436-
vf2[i] = v2[0];
437-
vf2[j] = v2[1];
497+
// arma::vec vf1 = arma::zeros<arma::vec>(p_);
498+
// arma::vec vf2 = arma::zeros<arma::vec>(p_);
499+
// vf1[i] = v1[0];
500+
// vf1[j] = v1[1];
501+
// vf2[i] = v2[0];
502+
// vf2[j] = v2[1];
438503

439-
arma::vec u1 = (vf1 + vf2) / sqrt(2);
440-
arma::vec u2 = (vf1 - vf2) / sqrt(2);
504+
// arma::vec u1 = (vf1 + vf2) / sqrt(2);
505+
// arma::vec u2 = (vf1 - vf2) / sqrt(2);
441506

442-
// Update Cholesky factor
443-
cholesky_update(phi_, u1);
444-
cholesky_downdate(phi_, u2);
507+
// // Update Cholesky factor
508+
// cholesky_update(phi_, u1);
509+
// cholesky_downdate(phi_, u2);
445510

446-
// Update inverse
447-
arma::inv(inv_phi_, arma::trimatu(phi_));
448-
inv_omega_ = inv_phi_ * inv_phi_.t();
511+
// // Update inverse
512+
// arma::inv(inv_phi_, arma::trimatu(phi_));
513+
// inv_omega_ = inv_phi_ * inv_phi_.t();
449514
}
450515
}
451516
}

src/ggm_model.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ class GGMModel : public BaseModel {
133133
arma::mat omega_prop_;
134134
arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4]
135135

136+
arma::vec v1_ = {0, -1};
137+
arma::vec v2_ = {0, 0};
138+
arma::vec vf1_ = arma::zeros<arma::vec>(p_);
139+
arma::vec vf2_ = arma::zeros<arma::vec>(p_);
140+
arma::vec u1_ = arma::zeros<arma::vec>(p_);
141+
arma::vec u2_ = arma::zeros<arma::vec>(p_);
142+
136143
// Parameter group updates with optimized likelihood evaluations
137144
void update_edge_parameter(size_t i, size_t j);
138145
void update_diagonal_parameter(size_t i);
@@ -147,6 +154,8 @@ class GGMModel : public BaseModel {
147154
double log_density_impl_edge(size_t i, size_t j) const;
148155
double log_density_impl_diag(size_t j) const;
149156
double get_log_det(arma::mat triangular_A) const;
157+
void cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j);
158+
void cholesky_update_after_diag(double omega_ii_old, size_t i);
150159
// double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j);
151160
// double find_reasonable_step_size_diag(const arma::mat& omega, size_t i);
152161
// double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal);

0 commit comments

Comments
 (0)