@@ -1604,56 +1604,78 @@ class colvar_grid_gradient : public colvar_grid<cvm::real>
16041604 cvm::real smoothing = 0 ) {
16051605
16061606 if (smoothing && weights->value (bin_value) < full_samples) {
1607- if (smoothing < 0 )
1608- cvm::error (" kernel parameter for kernel grid ABF is set inferior to 0" , COLVARS_INPUT_ERROR);
1609- cvm::real kernel_params = smoothing * (1 - std::max (0 .,weights->value (bin_value)-min_samples) / (full_samples-min_samples)); // * weights->value(bin_value) / full_samples
1610- cvm::real inv_squared_smooth = 1 / (std::max (kernel_params*kernel_params, 1e-5 ));
1611- int cutoff = static_cast <int >(cvm::floor (cutoff_factor * kernel_params)); // take like floor()
1612- for (size_t i = 0 ; i < nd; i++) {
1613- cutoff = std::min (cutoff, nx[i]/2 );
1614- }
1615- // We will use these to iterate
1616- std::vector<int > ix_min (nd);
1617- std::vector<int > ix_max (nd);
1618- std::vector<int > current_ix (nd);
1619- std::vector<int > wrapped_ix (nd);
1607+ if (smoothing < 0 )
1608+ cvm::error (" kernel parameter for kernel grid ABF is set inferior to 0" , COLVARS_INPUT_ERROR);
1609+
1610+ cvm::real kernel_params = smoothing; // * (1 - std::max(0., weights->value(bin_value) - min_samples) / (full_samples - min_samples)); //TODO: uncomment
1611+ cvm::real inv_squared_smooth = 1.0 / (std::max (kernel_params * kernel_params, 1e-5 ));
1612+ int cutoff = static_cast <int >(cvm::floor (cutoff_factor * kernel_params));
1613+
1614+ for (size_t i = 0 ; i < nd; i++) {
1615+ cutoff = std::min (cutoff, nx[i] / 2 );
1616+ }
1617+
1618+ // 1. Pre-calculate 1D weights and wrapped indices for each dimension
1619+ std::vector<std::vector<cvm::real>> w_1d (nd);
1620+ std::vector<std::vector<int >> idx_1d (nd);
1621+ cvm::real total_sum = 1.0 ;
16201622
16211623 for (size_t i = 0 ; i < nd; i++) {
1622- // Calculate raw bounds ( can be negative or > nx[i])
1623- ix_min[i] = static_cast <int >(cvm ::floor (cv_value[i] - cutoff));
1624- ix_max[i] = static_cast <int >(cvm ::floor (cv_value[i] + cutoff));
1624+ // can be negative or > nx[i] to allow for distance calculation
1625+ int i_min = static_cast <int >(std ::floor (cv_value[i] - cutoff));
1626+ int i_max = static_cast <int >(std ::floor (cv_value[i] + cutoff));
16251627
1626- // If NOT periodic, clamp to grid boundaries
16271628 if (!periodic[i]) {
1628- if (ix_min[i] < 0 ) ix_min[i] = 0 ;
1629- if (ix_max[i] >= nx[i]) ix_max[i] = nx[i] - 1 ;
1629+ if (i_min < 0 ) i_min = 0 ;
1630+ if (i_max >= nx[i]) i_max = nx[i] - 1 ;
16301631 }
1631- current_ix[i] = ix_min[i];
1632+ w_1d[i].resize (i_max - i_min + 1 );
1633+ cvm::real dim_sum = 0.0 ;
1634+ int counter = 0 ;
1635+ for (int ix = i_min; ix <= i_max; ix++) {
1636+ // Calculate 1D Gaussian component
1637+ cvm::real diff = (static_cast <cvm::real>(ix) + 0.5 ) - cv_value[i];
1638+ cvm::real weight = cvm::exp (-diff * diff * inv_squared_smooth);
1639+
1640+ w_1d[i][counter++] = weight;
1641+ dim_sum += weight;
1642+
1643+ if (periodic[i]) {
1644+ // Safe modulo for negative numbers
1645+ idx_1d[i].push_back ((ix % nx[i] + nx[i]) % nx[i]);
1646+ } else {
1647+ idx_1d[i].push_back (ix);
1648+ }
1649+ }
1650+ // The N-D sum is the product of the 1D sums
1651+ // total_sum *= dim_sum; //TODO : UNCOMMENT
16321652 }
16331653
1654+ cvm::real inv_total_sum = 1.0 / total_sum;
1655+
1656+ std::vector<int > current_ix (nd, 0 );
1657+ std::vector<int > wrapped_ix (nd);
16341658 bool done = false ;
1659+
16351660 while (!done) {
1636- cvm::real dist_sq = 0 ;
1661+ cvm::real combined_weight = inv_total_sum ;
16371662 for (size_t i = 0 ; i < nd; i++) {
1638- cvm::real diff = (static_cast <cvm::real>(current_ix[i]) + 0.5 ) - cv_value[i];
1639- dist_sq += diff * diff;
1640- if (periodic[i]) {
1641- wrapped_ix[i] = (current_ix[i] % (int )nx[i] + (int )nx[i]) % (int )nx[i];
1642- } else {
1643- wrapped_ix[i] = current_ix[i];
1644- }
1663+ int local_pos = current_ix[i];
1664+ combined_weight *= w_1d[i][local_pos];
1665+ wrapped_ix[i] = idx_1d[i][local_pos];
16451666 }
1646- cvm::real weight = cvm::exp (-dist_sq * inv_squared_smooth);
1647- acc_force (wrapped_ix, force, weight);
16481667
1668+ acc_force (wrapped_ix, force, combined_weight);
1669+
1670+ // iterates through the kernel support
16491671 for (int i = nd - 1 ; i >= 0 ; i--) {
1650- if (++current_ix[i] > ix_max [i]) {
1672+ if (++current_ix[i] >= static_cast < int >(w_1d [i]. size ()) ) {
16511673 if (i == 0 ) {
1652- done = true ;
1653- break ;
1674+ done = true ;
1675+ break ;
16541676 }
1655- current_ix[i] = ix_min[i] ;
1656- } else {
1677+ current_ix[i] = 0 ;
1678+ } else {
16571679 break ;
16581680 }
16591681 }
0 commit comments