1+ from itertools import chain
2+
13import torch
24
35from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
6+ from pytensor .scalar import ScalarLoop
47from pytensor .tensor .elemwise import DimShuffle , Elemwise
58from pytensor .tensor .math import All , Any , Max , Min , Prod , Sum
69from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
@@ -11,9 +14,38 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
1114 scalar_op = op .scalar_op
1215 base_fn = pytorch_funcify (scalar_op , node = node , ** kwargs )
1316
14- def elemwise_fn (* inputs ):
15- Elemwise ._check_runtime_broadcast (node , inputs )
16- return base_fn (* inputs )
17+ if isinstance (scalar_op , ScalarLoop ):
18+ # note: scalarloop + elemwise is too common
19+ # to not work, but @1031, vmap won't allow it.
20+ # Instead, we will just successively unbind
21+ def elemwise_fn (* inputs ):
22+ Elemwise ._check_runtime_broadcast (node , inputs )
23+ shaped_inputs = torch .broadcast_tensors (* inputs )
24+ expected_size = shaped_inputs [0 ].numel ()
25+ final_inputs = [s .clone () for s in shaped_inputs ]
26+ for _ in range (shaped_inputs [0 ].dim () - 1 ):
27+ for i , _ in enumerate (shaped_inputs ):
28+ layer = chain .from_iterable ([s .unbind (0 ) for s in final_inputs [i ]])
29+ final_inputs [i ] = list (layer )
30+
31+ # make sure we still have the same number of things
32+ assert len (final_inputs ) == len (shaped_inputs )
33+
34+ # make sure each group of things are the expected size
35+ assert all (len (x ) == expected_size for x in final_inputs )
36+
37+ # make sure they are all single elements
38+ assert all (len (x .shape ) == 0 for tensor in final_inputs for x in tensor )
39+ res = [base_fn (* args ) for args in zip (* final_inputs )]
40+ states = torch .stack (tuple (out [0 ] for out in res ))
41+ done = torch .stack (tuple (out [1 ] for out in res ))
42+ return states , done
43+
44+ else :
45+
46+ def elemwise_fn (* inputs ):
47+ Elemwise ._check_runtime_broadcast (node , inputs )
48+ return base_fn (* inputs )
1749
1850 return elemwise_fn
1951
0 commit comments