We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
smarttypeX
1 parent 19be124 commit 914e10fCopy full SHA for 914e10f
docs/source/api/pytensorf.rst
@@ -16,6 +16,7 @@ PyTensor utils
16
floatX
17
intX
18
smartfloatX
19
+ smarttypeX
20
constant_fold
21
CallableTensor
22
join_nonshared_inputs
pymc/pytensorf.py
@@ -68,6 +68,7 @@
68
"floatX",
69
"intX",
70
"smartfloatX",
71
+ "smarttypeX",
72
"jacobian",
73
"CallableTensor",
74
"join_nonshared_inputs",
@@ -297,6 +298,14 @@ def smartfloatX(x):
297
298
return x
299
300
301
+def smarttypeX(x):
302
+ if str(x.dtype).startswith("float"):
303
+ x = floatX(x)
304
+ elif str(x.dtype).startswith("int"):
305
+ x = intX(x)
306
+ return x
307
+
308
309
"""
310
PyTensor derivative functions
311
0 commit comments