Skip to content

Commit a6c6d93

Browse files
committed
improve code performance by using the kernel's separability
1 parent 1f828e3 commit a6c6d93

File tree

1 file changed

+57
-35
lines changed

1 file changed

+57
-35
lines changed

src/colvargrid.h

Lines changed: 57 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1604,56 +1604,78 @@ class colvar_grid_gradient : public colvar_grid<cvm::real>
16041604
cvm::real smoothing = 0) {
16051605

16061606
if (smoothing && weights->value(bin_value) < full_samples) {
1607-
if (smoothing < 0)
1608-
cvm::error("kernel parameter for kernel grid ABF is set inferior to 0", COLVARS_INPUT_ERROR);
1609-
cvm::real kernel_params = smoothing * (1 - std::max(0.,weights->value(bin_value)-min_samples) / (full_samples-min_samples)); // * weights->value(bin_value) / full_samples
1610-
cvm::real inv_squared_smooth = 1/ (std::max(kernel_params*kernel_params, 1e-5));
1611-
int cutoff = static_cast<int>(cvm::floor(cutoff_factor * kernel_params)); // take like floor()
1612-
for (size_t i = 0; i < nd; i++) {
1613-
cutoff = std::min(cutoff, nx[i]/2);
1614-
}
1615-
// We will use these to iterate
1616-
std::vector<int> ix_min(nd);
1617-
std::vector<int> ix_max(nd);
1618-
std::vector<int> current_ix(nd);
1619-
std::vector<int> wrapped_ix(nd);
1607+
if (smoothing < 0)
1608+
cvm::error("kernel parameter for kernel grid ABF is set inferior to 0", COLVARS_INPUT_ERROR);
1609+
1610+
cvm::real kernel_params = smoothing; // * (1 - std::max(0., weights->value(bin_value) - min_samples) / (full_samples - min_samples)); //TODO: uncomment
1611+
cvm::real inv_squared_smooth = 1.0 / (std::max(kernel_params * kernel_params, 1e-5));
1612+
int cutoff = static_cast<int>(cvm::floor(cutoff_factor * kernel_params));
1613+
1614+
for (size_t i = 0; i < nd; i++) {
1615+
cutoff = std::min(cutoff, nx[i] / 2);
1616+
}
1617+
1618+
// 1. Pre-calculate 1D weights and wrapped indices for each dimension
1619+
std::vector<std::vector<cvm::real>> w_1d(nd);
1620+
std::vector<std::vector<int>> idx_1d(nd);
1621+
cvm::real total_sum = 1.0;
16201622

16211623
for (size_t i = 0; i < nd; i++) {
1622-
// Calculate raw bounds (can be negative or > nx[i])
1623-
ix_min[i] = static_cast<int>(cvm::floor(cv_value[i] - cutoff));
1624-
ix_max[i] = static_cast<int>(cvm::floor(cv_value[i] + cutoff));
1624+
// can be negative or > nx[i] to allow for distance calculation
1625+
int i_min = static_cast<int>(std::floor(cv_value[i] - cutoff));
1626+
int i_max = static_cast<int>(std::floor(cv_value[i] + cutoff));
16251627

1626-
// If NOT periodic, clamp to grid boundaries
16271628
if (!periodic[i]) {
1628-
if (ix_min[i] < 0) ix_min[i] = 0;
1629-
if (ix_max[i] >= nx[i]) ix_max[i] = nx[i] - 1;
1629+
if (i_min < 0) i_min = 0;
1630+
if (i_max >= nx[i]) i_max = nx[i] - 1;
16301631
}
1631-
current_ix[i] = ix_min[i];
1632+
w_1d[i].resize(i_max - i_min + 1);
1633+
cvm::real dim_sum = 0.0;
1634+
int counter = 0;
1635+
for (int ix = i_min; ix <= i_max; ix++) {
1636+
// Calculate 1D Gaussian component
1637+
cvm::real diff = (static_cast<cvm::real>(ix) + 0.5) - cv_value[i];
1638+
cvm::real weight = cvm::exp(-diff * diff * inv_squared_smooth);
1639+
1640+
w_1d[i][counter++] = weight;
1641+
dim_sum += weight;
1642+
1643+
if (periodic[i]) {
1644+
// Safe modulo for negative numbers
1645+
idx_1d[i].push_back((ix % nx[i] + nx[i]) % nx[i]);
1646+
} else {
1647+
idx_1d[i].push_back(ix);
1648+
}
1649+
}
1650+
// The N-D sum is the product of the 1D sums
1651+
// total_sum *= dim_sum; //TODO : UNCOMMENT
16321652
}
16331653

1654+
cvm::real inv_total_sum = 1.0 / total_sum;
1655+
1656+
std::vector<int> current_ix(nd, 0);
1657+
std::vector<int> wrapped_ix(nd);
16341658
bool done = false;
1659+
16351660
while (!done) {
1636-
cvm::real dist_sq = 0;
1661+
cvm::real combined_weight = inv_total_sum;
16371662
for (size_t i = 0; i < nd; i++) {
1638-
cvm::real diff = (static_cast<cvm::real>(current_ix[i]) + 0.5) - cv_value[i];
1639-
dist_sq += diff * diff;
1640-
if (periodic[i]) {
1641-
wrapped_ix[i] = (current_ix[i] % (int)nx[i] + (int)nx[i]) % (int)nx[i];
1642-
} else {
1643-
wrapped_ix[i] = current_ix[i];
1644-
}
1663+
int local_pos = current_ix[i];
1664+
combined_weight *= w_1d[i][local_pos];
1665+
wrapped_ix[i] = idx_1d[i][local_pos];
16451666
}
1646-
cvm::real weight = cvm::exp(-dist_sq * inv_squared_smooth);
1647-
acc_force(wrapped_ix, force, weight);
16481667

1668+
acc_force(wrapped_ix, force, combined_weight);
1669+
1670+
// iterates through the kernel support
16491671
for (int i = nd - 1; i >= 0; i--) {
1650-
if (++current_ix[i] > ix_max[i]) {
1672+
if (++current_ix[i] >= static_cast<int>(w_1d[i].size())) {
16511673
if (i == 0) {
1652-
done = true;
1653-
break;
1674+
done = true;
1675+
break;
16541676
}
1655-
current_ix[i] = ix_min[i];
1656-
} else {
1677+
current_ix[i] = 0;
1678+
} else {
16571679
break;
16581680
}
16591681
}

0 commit comments

Comments
 (0)