38
38
from pymc .exceptions import NotConstantValueError
39
39
from pymc .logprob .utils import ParameterValueError
40
40
from pymc .pytensorf import (
41
+ GeneratorOp ,
41
42
collect_default_updates ,
42
43
compile_pymc ,
43
44
constant_fold ,
44
- convert_observed_data ,
45
+ convert_data ,
46
+ convert_generator_data ,
45
47
extract_obs_data ,
46
48
hessian ,
47
49
hessian_diag ,
@@ -188,9 +190,9 @@ def test_extract_obs_data():
188
190
189
191
190
192
@pytest .mark .parametrize ("input_dtype" , ["int32" , "int64" , "float32" , "float64" ])
191
- def test_convert_observed_data (input_dtype ):
193
+ def test_convert_data (input_dtype ):
192
194
"""
193
- Ensure that convert_observed_data returns the dense array, masked array,
195
+ Ensure that convert_data returns the dense array, masked array,
194
196
graph variable, TensorVariable, or sparse matrix as appropriate.
195
197
"""
196
198
# Create the various inputs to the function
@@ -206,12 +208,8 @@ def test_convert_observed_data(input_dtype):
206
208
missing_pandas_input = pd .DataFrame (missing_numpy_input )
207
209
masked_array_input = ma .array (dense_input , mask = (np .mod (dense_input , 2 ) == 0 ))
208
210
209
- # Create a generator object. Apparently the generator object needs to
210
- # yield numpy arrays.
211
- square_generator = (np .array ([i ** 2 ], dtype = int ) for i in range (100 ))
212
-
213
211
# Alias the function to be tested
214
- func = convert_observed_data
212
+ func = convert_data
215
213
216
214
#####
217
215
# Perform the various tests
@@ -255,21 +253,35 @@ def test_convert_observed_data(input_dtype):
255
253
else :
256
254
assert pytensor_output .dtype == intX
257
255
258
- # Check function behavior with generator data
259
- generator_output = func (square_generator )
260
256
261
- # Output is wrapped with `pm.floatX`, and this unwraps
262
- wrapped = generator_output .owner .inputs [0 ]
263
- # Make sure the returned object has .set_gen and .set_default methods
264
- assert hasattr (wrapped , "set_gen" )
265
- assert hasattr (wrapped , "set_default" )
257
+ @pytest .mark .parametrize ("input_dtype" , ["int32" , "int64" , "float32" , "float64" ])
258
+ def test_convert_generator_data (input_dtype ):
259
+ # Create a generator object producing NumPy arrays with the intended dtype.
260
+ # This is required to infer the correct dtype.
261
+ square_generator = (np .array ([i ** 2 ], dtype = input_dtype ) for i in range (100 ))
262
+
263
+ # Output is NOT wrapped with `pm.floatX`/`intX`,
264
+ # but produced from calling a special Op.
265
+ result = convert_generator_data (square_generator )
266
+ apply = result .owner
267
+ op = apply .op
266
268
# Make sure the returned object is an PyTensor TensorVariable
267
- assert isinstance (wrapped , TensorVariable )
269
+ assert isinstance (result , TensorVariable )
270
+ assert isinstance (op , GeneratorOp ), f"It's a { type (apply )} "
271
+ # There are no inputs - because it generates...
272
+ assert apply .inputs == []
273
+
274
+ # Evaluation results should have the correct* dtype!
275
+ # (*intX/floatX will be enforced!)
276
+ evaled = result .eval ()
277
+ expected_dtype = pm .smarttypeX (np .array (1 , dtype = input_dtype )).dtype
278
+ assert result .type .dtype == expected_dtype
279
+ assert evaled .dtype == np .dtype (expected_dtype )
268
280
269
281
270
282
def test_pandas_to_array_pandas_index ():
271
283
data = pd .Index ([1 , 2 , 3 ])
272
- result = convert_observed_data (data )
284
+ result = convert_data (data )
273
285
expected = np .array ([1 , 2 , 3 ])
274
286
np .testing .assert_array_equal (result , expected )
275
287
0 commit comments