Skip to content

Commit 2a46030

Browse files
authored
Update paths to data files (#1996)
* Update paths to data files * address numerical issues due to new jax release * fix the size of the new SP500 file
1 parent 627d19a commit 2a46030

File tree

5 files changed

+22
-15
lines changed

5 files changed

+22
-15
lines changed

numpyro/examples/datasets.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
dset = namedtuple("dset", ["name", "urls"])
2929

3030
BASEBALL = dset(
31-
"baseball", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt"]
31+
"baseball",
32+
["https://github.com/pyro-ppl/datasets/blob/master/EfronMorrisBB.txt?raw=true"],
3233
)
3334

3435
BOSTON_HOUSING = dset(
@@ -37,7 +38,7 @@
3738
)
3839

3940
COVTYPE = dset(
40-
"covtype", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip"]
41+
"covtype", ["https://github.com/pyro-ppl/datasets/blob/master/covtype.npz?raw=true"]
4142
)
4243

4344
DIPPER_VOLE = dset(
@@ -48,26 +49,32 @@
4849
MNIST = dset(
4950
"mnist",
5051
[
51-
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz",
52-
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz",
53-
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz",
54-
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz",
52+
"https://github.com/pyro-ppl/datasets/blob/master/mnist/train-images-idx3-ubyte.gz?raw=true",
53+
"https://github.com/pyro-ppl/datasets/blob/master/mnist/train-labels-idx1-ubyte.gz?raw=true",
54+
"https://github.com/pyro-ppl/datasets/blob/master/mnist/t10k-images-idx3-ubyte.gz?raw=true",
55+
"https://github.com/pyro-ppl/datasets/blob/master/mnist/t10k-labels-idx1-ubyte.gz?raw=true",
5556
],
5657
)
5758

58-
SP500 = dset("SP500", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv"])
59+
SP500 = dset(
60+
"SP500", ["https://github.com/pyro-ppl/datasets/blob/master/SP500.csv?raw=true"]
61+
)
5962

6063
UCBADMIT = dset(
61-
"ucbadmit", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/UCBadmit.csv"]
64+
"ucbadmit",
65+
["https://github.com/pyro-ppl/datasets/blob/master/UCBadmit.csv?raw=true"],
6266
)
6367

6468
LYNXHARE = dset(
65-
"lynxhare", ["https://d2hg8soec8ck9v.cloudfront.net/datasets/LynxHare.txt"]
69+
"lynxhare",
70+
["https://github.com/pyro-ppl/datasets/blob/master/LynxHare.txt?raw=true"],
6671
)
6772

6873
JSB_CHORALES = dset(
6974
"jsb_chorales",
70-
["https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle"],
75+
[
76+
"https://github.com/pyro-ppl/datasets/blob/master/polyphonic/jsb_chorales.pickle?raw=true"
77+
],
7178
)
7279

7380
HIGGS = dset(
@@ -129,7 +136,7 @@ def _load_boston_housing():
129136
def _load_covtype():
130137
_download(COVTYPE)
131138

132-
file_path = os.path.join(DATA_DIR, "covtype.zip")
139+
file_path = os.path.join(DATA_DIR, "covtype.npz")
133140
data = np.load(file_path)
134141

135142
return {"train": (data["data"], data["target"])}

test/infer/test_autoguide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def model(x, y):
405405
b = numpyro.sample("b", dist.Normal(0, 10).expand([3]).to_event())
406406
mu = a + b[0] * x + b[1] * x**2 + b[2] * x**3
407407
with numpyro.plate("N", len(x)):
408-
numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)
408+
numpyro.sample("y", dist.Normal(mu, 0.00001), obs=y)
409409

410410
x = random.normal(random.PRNGKey(0), (3,))
411411
y = 1 + 2 * x + 3 * x**2 + 4 * x**3

test/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,7 @@ def test_entropy_categorical():
15861586
probs = _to_probs_multinom(logits)
15871587
sp_dist = osp.multinomial(1, probs)
15881588
for jax_dist in [dist.CategoricalLogits(logits), dist.CategoricalProbs(probs)]:
1589-
assert_allclose(jax_dist.entropy(), sp_dist.entropy())
1589+
assert_allclose(jax_dist.entropy(), sp_dist.entropy(), rtol=1e-6, atol=1e-6)
15901590

15911591

15921592
def test_mixture_log_prob():

test/test_example_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def mean_pixels(i, mean_pix):
4242
def test_sp500_data_load():
4343
_, fetch = load_dataset(SP500, split="train", shuffle=False)
4444
date, value = fetch()
45-
assert jnp.shape(date) == jnp.shape(date) == (2427,)
45+
assert jnp.shape(date) == jnp.shape(date) == (2517,)
4646

4747

4848
def test_jsb_chorales():

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def test_bijective_transforms(transform, shape):
322322
if isinstance(transform, less_stable_transforms):
323323
atol = 1e-2
324324
elif isinstance(transform, (L1BallTransform, RecursiveLinearTransform)):
325-
atol = 0.1
325+
atol = 0.2
326326
assert jnp.allclose(x1, x2, atol=atol)
327327

328328
log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y)

0 commit comments

Comments
 (0)