|
2 | 2 | from numpy.testing import assert_allclose |
3 | 3 | import mxnet as mx |
4 | 4 | from mxnet import gluon |
| 5 | +from scipy.stats import ks_2samp |
5 | 6 | import pytest |
6 | 7 | from gluonnlp.op import * |
7 | 8 | mx.npx.set_np() |
@@ -102,12 +103,36 @@ def test_gumbel_softmax(shape): |
102 | 103 | assume_allones = (ret == 1).sum(axis=-1).asnumpy() |
103 | 104 | assert_allclose(assume_allones, np.ones_like(assume_allones)) |
104 | 105 |
|
105 | | - |
@pytest.mark.parametrize('shape', (50,))
@pytest.mark.seed(1)
def test_trunc_gumbel(shape):
    """Verify trunc_gumbel draws from a Gumbel distribution truncated at a threshold.

    Two checks:
    1. Truncation: every sample lies strictly below the provided threshold.
    2. Distribution: inverting the truncation transform on the samples should
       recover draws that are statistically indistinguishable from an
       untruncated Gumbel.  Verified with a two-sample Kolmogorov-Smirnov test
       against freshly drawn (untruncated) Gumbel samples.
    """
    # Check 1: samples are truncated below the threshold.
    threshold = 1.0
    for _ in range(1000):
        samples = trunc_gumbel(mx.np.ones(shape), threshold).asnumpy()
        assert (samples < threshold).all()

    # Check 2: KS tests between de-truncated samples and plain Gumbel samples.
    # Single source of truth for the truncation point used in both the sampler
    # call and the inverse transform below.
    truncation = 0.5
    pvalues = []
    for _ in range(1000):
        logits = mx.np.random.uniform(-2, -1, shape)
        # Reference: an untruncated Gumbel(logits) sample.
        sampled_gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits

        # Candidate: a truncated sample with the truncation removed.
        # If g ~ TruncGumbel(logits, b), then -log(exp(-g) - exp(-b)) ~ Gumbel(logits).
        sampled_truncated_gumbels = trunc_gumbel(logits, truncation)
        reconstructed_sample = -mx.np.log(
            mx.np.exp(-sampled_truncated_gumbels) - mx.np.exp(-truncation))

        pvalues.append(ks_2samp(reconstructed_sample.asnumpy(),
                                sampled_gumbels.asnumpy()).pvalue)

    pvalues = np.array(pvalues)
    # Accept the null hypothesis (the reconstructed samples come from a Gumbel
    # distribution) if more than 90% of the 1000 per-run KS tests fail to
    # reject it at significance level 0.05.
    assert (pvalues > 0.05).sum() > 900
0 commit comments