Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 83780ab

Browse files
authored
[TEST] Add statistical inference for Truncated Gumbel distribution (#1578)
* Add statistical inference for Truncated Gumbel distribution
* Fix mxnet issue with test
1 parent f5a9fc1 commit 83780ab

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

tests/test_op.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from numpy.testing import assert_allclose
33
import mxnet as mx
44
from mxnet import gluon
5+
from scipy.stats import ks_2samp
56
import pytest
67
from gluonnlp.op import *
78
mx.npx.set_np()
@@ -102,12 +103,36 @@ def test_gumbel_softmax(shape):
102103
assume_allones = (ret == 1).sum(axis=-1).asnumpy()
103104
assert_allclose(assume_allones, np.ones_like(assume_allones))
104105

105-
106+
@pytest.mark.parametrize('shape', (50,))
@pytest.mark.seed(1)
def test_trunc_gumbel(shape):
    """Check that ``trunc_gumbel`` samples are truncated and Gumbel-distributed.

    The test has two parts:

    1. Truncation: every sample drawn with a given threshold must be strictly
       below that threshold.
    2. Distribution: the truncation is algebraically undone and the recovered
       samples are compared, via a two-sample Kolmogorov-Smirnov test, against
       fresh Gumbel samples shifted by the same logits.

    Parameters
    ----------
    shape
        Shape of the logits passed to ``trunc_gumbel`` (supplied by
        ``pytest.mark.parametrize``).
    """
    num_trials = 1000

    # Part 1: samples must never reach the truncation threshold.
    upper_bound = 1.0
    for _ in range(num_trials):
        samples = trunc_gumbel(mx.np.ones(shape), upper_bound).asnumpy()
        assert (samples < upper_bound).all()

    # Part 2: KS-tests on the de-truncated samples.
    # A single name is shared by the sampler call and the inverse transform so
    # the two thresholds cannot drift apart.
    truncation = 0.5
    pvalues = []
    for _ in range(num_trials):
        logits = mx.np.random.uniform(-2, -1, shape)

        # Reference sample: standard Gumbel noise shifted by the logits.
        sampled_gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits

        # Candidate sample from the operator under test.
        sampled_truncated_gumbels = trunc_gumbel(logits, truncation)

        # Remove the truncation. NOTE(review): this assumes trunc_gumbel applies
        # g_trunc = -log(exp(-g) + exp(-truncation)); confirm against gluonnlp.op.
        reconstructed_sample = -mx.np.log(
            mx.np.exp(-sampled_truncated_gumbels) - mx.np.exp(-truncation))

        pvalue = ks_2samp(reconstructed_sample.asnumpy(),
                          sampled_gumbels.asnumpy()).pvalue
        pvalues.append(pvalue)

    pvalues = np.array(pvalues)
    # Statistical inference condition: if more than 90% of the KS-tests yield a
    # p-value > 0.05, accept the null hypothesis (i.e. the reconstructed samples
    # are indeed drawn from a Gumbel distribution).
    num_accepted = int((pvalues > 0.05).sum())
    min_accepted = num_trials * 9 // 10
    assert num_accepted > min_accepted, (
        'Only {} of {} KS-tests had p-value > 0.05; expected more than {}.'.format(
            num_accepted, num_trials, min_accepted))
138+

0 commit comments

Comments
 (0)