File tree Expand file tree Collapse file tree 1 file changed +15
-1
lines changed
pytensor/link/pytorch/dispatch Expand file tree Collapse file tree 1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change 1+ import importlib
2+
13import torch
24
35from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
1012def pytorch_funcify_Elemwise (op , node , ** kwargs ):
1113 scalar_op = op .scalar_op
1214 base_fn = pytorch_funcify (scalar_op , node = node , ** kwargs )
15+
16+ def check_special_scipy (func_name ):
17+ if "scipy." not in func_name :
18+ return False
19+ loc = func_name .split ("." )[1 :]
20+ try :
21+ mod = importlib .import_module ("." .join (loc [:- 1 ]), "torch" )
22+ return getattr (mod , loc [- 1 ], False )
23+ except ImportError :
24+ return False
25+
1326 if hasattr (scalar_op , "nfunc_spec" ) and (
14- hasattr (torch , scalar_op .nfunc_spec [0 ]) or "scipy." in scalar_op .nfunc_spec [0 ]
27+ hasattr (torch , scalar_op .nfunc_spec [0 ])
28+ or check_special_scipy (scalar_op .nfunc_spec [0 ])
1529 ):
1630 # torch can handle this scalar
1731 # broadcast, we'll let it.
You can’t perform that action at this time.
0 commit comments