Commit 5880403

remove the generator classes
1 parent e0e7511 commit 5880403

2 files changed: +2, -156 lines


pymc/data.py

Lines changed: 1 addition & 50 deletions
@@ -33,18 +33,16 @@
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.basic import IntegersRV
-from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable

 import pymc as pm

-from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
+from pymc.pytensorf import convert_data
 from pymc.vartypes import isgenerator

 __all__ = [
     "ConstantData",
     "Data",
-    "GeneratorAdapter",
     "Minibatch",
     "MutableData",
     "get_data",
@@ -86,51 +84,6 @@ def clone(self):
         return cp


-class GeneratorAdapter:
-    """Class that helps infer data type of generator.
-
-    It looks at the first item, preserving the order of the resulting generator.
-    """
-
-    def make_variable(self, gop, name=None):
-        var = GenTensorVariable(gop, self.tensortype, name)
-        var.tag.test_value = self.test_value
-        return var
-
-    def __init__(self, generator):
-        if not pm.vartypes.isgenerator(generator):
-            raise TypeError("Object should be generator like")
-        self.test_value = smarttypeX(copy(next(generator)))
-        # make pickling potentially possible
-        self._yielded_test_value = False
-        self.gen = generator
-        self.tensortype = TensorType(self.test_value.dtype, ((False,) * self.test_value.ndim))
-
-    # python3 generator
-    def __next__(self):
-        """Next value in the generator."""
-        if not self._yielded_test_value:
-            self._yielded_test_value = True
-            return self.test_value
-        else:
-            return smarttypeX(copy(next(self.gen)))
-
-    # python2 generator
-    next = __next__
-
-    def __iter__(self):
-        """Return an iterator."""
-        return self
-
-    def __eq__(self, other):
-        """Return true if both objects are actually the same."""
-        return id(self) == id(other)
-
-    def __hash__(self):
-        """Return a hash of the object."""
-        return hash(id(self))
-
-
 class MinibatchIndexRV(IntegersRV):
     _print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")

@@ -170,8 +123,6 @@ def is_valid_observed(v) -> bool:
             isinstance(v.owner.op, MinibatchOp)
            and all(is_valid_observed(inp) for inp in v.owner.inputs)
         )
-        # Or Generator
-        or isinstance(v.owner.op, GeneratorOp)
     )
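The practical effect of the pymc/data.py changes is that a generator is no longer a valid observed-data container; plain arrays (and Minibatch variables) remain supported. Below is a minimal sketch of the resulting behavior, assuming observed values are still routed through convert_observed_data; the model and data are hypothetical:

import numpy as np
import pymc as pm

def batch_stream():
    # hypothetical infinite stream of observations
    while True:
        yield np.random.normal(size=10)

with pm.Model():
    mu = pm.Normal("mu")
    # Arrays (and anything convert_data accepts) are still fine as observed data.
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.normal(size=100))

    # After this commit, passing a generator raises
    # TypeError: Data passed to `observed` cannot be a generator.
    # pm.Normal("y_gen", mu=mu, sigma=1.0, observed=batch_stream())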

pymc/pytensorf.py

Lines changed: 1 addition & 106 deletions
@@ -36,7 +36,6 @@
     walk,
 )
 from pytensor.graph.fg import FunctionGraph, Output
-from pytensor.graph.op import Op
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
@@ -63,10 +62,8 @@
     "compile_pymc",
     "cont_inputs",
     "convert_data",
-    "convert_generator_data",
     "convert_observed_data",
     "floatX",
-    "generator",
     "gradient",
     "hessian",
     "hessian_diag",
@@ -81,20 +78,10 @@
 def convert_observed_data(data) -> np.ndarray | Variable:
     """Convert user provided dataset to accepted formats."""
     if isgenerator(data):
-        return convert_generator_data(data)
+        raise TypeError("Data passed to `observed` cannot be a generator.")
     return convert_data(data)


-def convert_generator_data(data) -> TensorVariable:
-    warnings.warn(
-        "Generator data is deprecated and we intend to remove it."
-        " If you disagree and need this, please get in touch via https://github.com/pymc-devs/pymc/issues.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return generator(data)
-
-
 def convert_data(data) -> np.ndarray | Variable:
     ret: np.ndarray | Variable
     if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
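A small sketch of the new convert_observed_data behavior introduced in the hunk above (the inputs are illustrative):

import numpy as np
from pymc.pytensorf import convert_observed_data

# Arrays and dataframes still go through convert_data as before.
arr = convert_observed_data(np.arange(5.0))

# Generators now fail fast instead of being wrapped in a GeneratorOp.
try:
    convert_observed_data(x for x in range(5))
except TypeError as err:
    print(err)  # Data passed to `observed` cannot be a generator.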
@@ -625,98 +612,6 @@ def __call__(self, input):
         return pytensor.clone_replace(self.tensor, {oldinput: input}, rebuild_strict=False)


-class GeneratorOp(Op):
-    """
-    Generator Op is designed for storing python generators inside pytensor graph.
-
-    __call__ creates TensorVariable
-    It has 2 new methods
-    - var.set_gen(gen): sets new generator
-    - var.set_default(value): sets new default value (None erases default value)
-
-    If generator is exhausted, variable will produce default value if it is not None,
-    else raises `StopIteration` exception that can be caught on runtime.
-
-    Parameters
-    ----------
-    gen: generator that implements __next__ (py3) or next (py2) method
-        and yields np.arrays with same types
-    default: np.array with the same type as generator produces
-    """
-
-    __props__ = ("generator",)
-
-    def __init__(self, gen, default=None):
-        warnings.warn(
-            "generator data is deprecated and will be removed in a future release", FutureWarning
-        )
-        from pymc.data import GeneratorAdapter
-
-        super().__init__()
-        if not isinstance(gen, GeneratorAdapter):
-            gen = GeneratorAdapter(gen)
-        self.generator = gen
-        self.set_default(default)
-
-    def make_node(self, *inputs):
-        gen_var = self.generator.make_variable(self)
-        return Apply(self, [], [gen_var])
-
-    def perform(self, node, inputs, output_storage, params=None):
-        if self.default is not None:
-            output_storage[0][0] = next(self.generator, self.default)
-        else:
-            output_storage[0][0] = next(self.generator)
-
-    def do_constant_folding(self, fgraph, node):
-        return False
-
-    __call__ = pytensor.config.change_flags(compute_test_value="off")(Op.__call__)
-
-    def set_gen(self, gen):
-        from pymc.data import GeneratorAdapter
-
-        if not isinstance(gen, GeneratorAdapter):
-            gen = GeneratorAdapter(gen)
-        if not gen.tensortype == self.generator.tensortype:
-            raise ValueError("New generator should yield the same type")
-        self.generator = gen
-
-    def set_default(self, value):
-        if value is None:
-            self.default = None
-        else:
-            value = np.asarray(value, self.generator.tensortype.dtype)
-            t1 = (False,) * value.ndim
-            t2 = self.generator.tensortype.broadcastable
-            if not t1 == t2:
-                raise ValueError("Default value should have the same type as generator")
-            self.default = value
-
-
-def generator(gen, default=None):
-    """
-    Create a generator variable with possibility to set default value and new generator.
-
-    If generator is exhausted variable will produce default value if it is not None,
-    else raises `StopIteration` exception that can be caught on runtime.
-
-    Parameters
-    ----------
-    gen: generator that implements __next__ (py3) or next (py2) method
-        and yields np.arrays with same types
-    default: np.array with the same type as generator produces
-
-    Returns
-    -------
-    TensorVariable
-        It has 2 new methods
-        - var.set_gen(gen): sets new generator
-        - var.set_default(value): sets new default value (None erases default value)
-    """
-    return GeneratorOp(gen, default)()
-
-
 def ix_(*args):
     """
     PyTensor np.ix_ analog.
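Minibatch (and MinibatchOp) are untouched by this commit, so minibatching remains the streaming mechanism these modules keep. As a hedged illustration only, since the commit itself does not prescribe a replacement for generator data, here is a sketch of that pattern, assuming the current pm.Minibatch(*variables, batch_size=...) signature and total_size rescaling:

import numpy as np
import pymc as pm

x_full = np.random.normal(size=(1000, 1))
y_full = np.random.normal(size=(1000,))

# Draw aligned random batches of both arrays.
x_mb, y_mb = pm.Minibatch(x_full, y_full, batch_size=128)

with pm.Model():
    beta = pm.Normal("beta")
    # total_size tells PyMC how to rescale the minibatch likelihood to the full dataset.
    pm.Normal("y", mu=x_mb[:, 0] * beta, sigma=1.0, observed=y_mb, total_size=1000)
    # approx = pm.fit(n=10_000)  # e.g. ADVI on minibatches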
