Skip to content

Commit 914e10f

Browse files
Add smarttypeX to apply intX or floatX as needed
This will be used to fix the `GeneratorAdapter` when applied to generators producing int-valued data.
1 parent 19be124 commit 914e10f

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

docs/source/api/pytensorf.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ PyTensor utils
1616
floatX
1717
intX
1818
smartfloatX
19+
smarttypeX
1920
constant_fold
2021
CallableTensor
2122
join_nonshared_inputs

pymc/pytensorf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
"floatX",
6969
"intX",
7070
"smartfloatX",
71+
"smarttypeX",
7172
"jacobian",
7273
"CallableTensor",
7374
"join_nonshared_inputs",
@@ -297,6 +298,14 @@ def smartfloatX(x):
297298
return x
298299

299300

301+
def smarttypeX(x):
302+
if str(x.dtype).startswith("float"):
303+
x = floatX(x)
304+
elif str(x.dtype).startswith("int"):
305+
x = intX(x)
306+
return x
307+
308+
300309
"""
301310
PyTensor derivative functions
302311
"""

0 commit comments

Comments
 (0)