Skip to content

Commit 8ba957d

Browse files
Fixed typo in Wald sampling
It gsc should have been sqrt(0.5*mean/scale), not 0.5*sqrt(mean/scale). Also made code transforming variables robust to significant digits cancellation.
1 parent f255c5f commit 8ba957d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

mkl_random/src/mkl_distributions.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ irk_wald_vec(irk_state *state, npy_intp len, double *res, const double mean, con
802802
int i, err;
803803
const double d_zero = 0., d_one = 1.0;
804804
double *uvec = NULL;
805-
double gsc = 0.5*sqrt(mean / scale);
805+
double gsc = sqrt(0.5*mean / scale);
806806

807807
if (len < 1)
808808
return;
@@ -817,15 +817,15 @@ irk_wald_vec(irk_state *state, npy_intp len, double *res, const double mean, con
817817
err = vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, state->stream, len, res, d_zero, gsc);
818818
assert(err == VSL_STATUS_OK);
819819

820-
/* Y = mean/(4 scale) * Z^2 */
820+
/* Y = mean/(2 scale) * Z^2 */
821821
vmdSqr(len, res, res, VML_HA);
822822

823823
DIST_PRAGMA_VECTOR
824824
for(i = 0; i < len; i++) {
825-
if(res[i] <= 1.0) {
826-
res[i] = 1.0 + res[i] - sqrt( res[i] * (res[i] + 2.0));
825+
if(res[i] <= 2.0) {
826+
res[i] = 1.0 + res[i] + sqrt(res[i] * (res[i] + 2.0));
827827
} else {
828-
res[i] = 1.0 - 2.0/(1.0 + sqrt( 1 + 2.0/res[i]));
828+
res[i] = 1.0 + res[i]*(1.0 + sqrt(1.0 + 2.0/res[i]));
829829
}
830830
}
831831

@@ -837,10 +837,10 @@ irk_wald_vec(irk_state *state, npy_intp len, double *res, const double mean, con
837837

838838
DIST_PRAGMA_VECTOR
839839
for(i=0; i<len; i++) {
840-
if (uvec[i]*(1.0 + res[i]) <= 1.0)
841-
res[i] = mean*res[i];
842-
else
840+
if (uvec[i]*(1.0 + res[i]) <= res[i])
843841
res[i] = mean/res[i];
842+
else
843+
res[i] = mean*res[i];
844844
}
845845

846846
mkl_free(uvec);

0 commit comments

Comments
 (0)