Skip to content

Commit 3447619

Browse files
authored
Add a type guard for intX (#4569)
* add type guard for inX * fix test for pandas * fix posterior test, ints passed for float data * release notes * Update RELEASE-NOTES.md * Update RELEASE-NOTES.md
1 parent 106cd8f commit 3447619

File tree

4 files changed

+8
-3
lines changed

4 files changed

+8
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)).
2525

2626
**Release manager** for 3.11.2: Michael Osthege ([@michaelosthege](https://github.com/michaelosthege))
27+
- `pm.intX` no longer downcasts integers unnecessarily (see [#4569](https://github.com/pymc-devs/pymc3/pull/4569))
2728

2829
## PyMC3 3.11.1 (12 February 2021)
2930

pymc3/tests/test_data_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_sample_posterior_predictive_after_set_data(self):
7878
trace = pm.sample(1000, tune=1000, chains=1)
7979
# Predict on new data.
8080
with model:
81-
x_test = [5, 6, 9]
81+
x_test = [5.0, 6.0, 9.0]
8282
pm.set_data(new_data={"x": x_test})
8383
y_test = pm.sample_posterior_predictive(trace)
8484
y_test1 = pm.fast_sample_posterior_predictive(trace)

pymc3/tests/test_model_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_pandas_to_array(self, input_dtype):
8282
assert isinstance(theano_output, theano.graph.basic.Variable)
8383
npt.assert_allclose(theano_output.eval(), theano_graph_input.eval())
8484
intX = pm.theanof._conversion_map[theano.config.floatX]
85-
if dense_input.dtype == intX or dense_input.dtype == theano.config.floatX:
85+
if "int" in str(dense_input.dtype) or dense_input.dtype == theano.config.floatX:
8686
assert theano_output.owner is None # func should not have added new nodes
8787
assert theano_output.name == input_name
8888
else:
@@ -92,7 +92,8 @@ def test_pandas_to_array(self, input_dtype):
9292
if "float" in input_dtype:
9393
assert theano_output.dtype == theano.config.floatX
9494
else:
95-
assert theano_output.dtype == intX
95+
# only cast floats, leave ints as is
96+
assert theano_output.dtype == input_dtype
9697

9798
# Check function behavior with generator data
9899
generator_output = func(square_generator)

pymc3/theanof.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def intX(X):
9393
"""
9494
Convert a theano tensor or numpy array to theano.tensor.int32 type.
9595
"""
96+
# check value is already int, do nothing in this case
97+
if (hasattr(X, "dtype") and "int" in str(X.dtype)) or isinstance(X, int):
98+
return X
9699
intX = _conversion_map[theano.config.floatX]
97100
try:
98101
return X.astype(intX)

0 commit comments

Comments
 (0)