|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import io |
16 | | -import itertools as it |
17 | 16 |
|
18 | 17 | from os import path |
19 | 18 |
|
20 | | -import cloudpickle |
21 | 19 | import numpy as np |
22 | 20 | import pytensor |
23 | 21 | import pytensor.tensor as pt |
|
29 | 27 | import pymc as pm |
30 | 28 |
|
31 | 29 | from pymc.data import MinibatchOp |
32 | | -from pymc.pytensorf import GeneratorOp, floatX |
| 30 | +from pymc.pytensorf import floatX |
33 | 31 |
|
34 | 32 |
|
35 | 33 | class TestData: |
@@ -495,83 +493,6 @@ def integers_ndim(ndim): |
495 | 493 | i += 1 |
496 | 494 |
|
497 | 495 |
|
498 | | -@pytest.mark.usefixtures("strict_float32") |
499 | | -class TestGenerator: |
500 | | - def test_basic(self): |
501 | | - generator = pm.GeneratorAdapter(integers()) |
502 | | - gop = GeneratorOp(generator)() |
503 | | - assert gop.tag.test_value == np.float32(0) |
504 | | - f = pytensor.function([], gop) |
505 | | - assert f() == np.float32(0) |
506 | | - assert f() == np.float32(1) |
507 | | - for _ in range(2, 100): |
508 | | - f() |
509 | | - assert f() == np.float32(100) |
510 | | - |
511 | | - def test_ndim(self): |
512 | | - for ndim in range(10): |
513 | | - res = list(it.islice(integers_ndim(ndim), 0, 2)) |
514 | | - generator = pm.GeneratorAdapter(integers_ndim(ndim)) |
515 | | - gop = GeneratorOp(generator)() |
516 | | - f = pytensor.function([], gop) |
517 | | - assert ndim == res[0].ndim |
518 | | - np.testing.assert_equal(f(), res[0]) |
519 | | - np.testing.assert_equal(f(), res[1]) |
520 | | - |
521 | | - def test_cloning_available(self): |
522 | | - gop = pm.generator(integers()) |
523 | | - res = gop**2 |
524 | | - shared = pytensor.shared(pm.floatX(10)) |
525 | | - res1 = pytensor.clone_replace(res, {gop: shared}) |
526 | | - f = pytensor.function([], res1) |
527 | | - assert f() == np.float32(100) |
528 | | - |
529 | | - def test_default_value(self): |
530 | | - def gen(): |
531 | | - for i in range(2): |
532 | | - yield pm.floatX(np.ones((10, 10)) * i) |
533 | | - |
534 | | - gop = pm.generator(gen(), np.ones((10, 10)) * 10) |
535 | | - f = pytensor.function([], gop) |
536 | | - np.testing.assert_equal(np.ones((10, 10)) * 0, f()) |
537 | | - np.testing.assert_equal(np.ones((10, 10)) * 1, f()) |
538 | | - np.testing.assert_equal(np.ones((10, 10)) * 10, f()) |
539 | | - with pytest.raises(ValueError): |
540 | | - gop.set_default(1) |
541 | | - |
542 | | - def test_set_gen_and_exc(self): |
543 | | - def gen(): |
544 | | - for i in range(2): |
545 | | - yield pm.floatX(np.ones((10, 10)) * i) |
546 | | - |
547 | | - gop = pm.generator(gen()) |
548 | | - f = pytensor.function([], gop) |
549 | | - np.testing.assert_equal(np.ones((10, 10)) * 0, f()) |
550 | | - np.testing.assert_equal(np.ones((10, 10)) * 1, f()) |
551 | | - with pytest.raises(StopIteration): |
552 | | - f() |
553 | | - gop.set_gen(gen()) |
554 | | - np.testing.assert_equal(np.ones((10, 10)) * 0, f()) |
555 | | - np.testing.assert_equal(np.ones((10, 10)) * 1, f()) |
556 | | - |
557 | | - def test_pickling(self, datagen): |
558 | | - gen = pm.generator(datagen) |
559 | | - cloudpickle.loads(cloudpickle.dumps(gen)) |
560 | | - bad_gen = pm.generator(integers()) |
561 | | - with pytest.raises(TypeError): |
562 | | - cloudpickle.dumps(bad_gen) |
563 | | - |
564 | | - def test_gen_cloning_with_shape_change(self, datagen): |
565 | | - gen = pm.generator(datagen) |
566 | | - gen_r = pt.random.normal(size=gen.shape).T |
567 | | - X = gen.dot(gen_r) |
568 | | - res, _ = pytensor.scan(lambda x: x.sum(), X, n_steps=X.shape[0]) |
569 | | - assert res.eval().shape == (50,) |
570 | | - shared = pytensor.shared(datagen.data.astype(gen.dtype)) |
571 | | - res2 = pytensor.clone_replace(res, {gen: shared**2}) |
572 | | - assert res2.eval().shape == (1000,) |
573 | | - |
574 | | - |
575 | 496 | def gen1(): |
576 | 497 | i = 0 |
577 | 498 | while True: |
|
0 commit comments