1+ from itertools import chain
2+
13import torch
24
35from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
6+ from pytensor .scalar import ScalarLoop
47from pytensor .tensor .elemwise import DimShuffle , Elemwise
58from pytensor .tensor .math import All , Any , Max , Min , Prod , Sum
69from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
@@ -17,6 +20,34 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
1720 def elemwise_fn (* inputs ):
1821 Elemwise ._check_runtime_broadcast (node , inputs )
1922 return base_fn (* inputs )
23+
24+ elif isinstance (scalar_op , ScalarLoop ):
25+ # note: scalarloop + elemwise is too common
26+ # to not work, but @1031, vmap won't allow it.
27+ # Instead, we will just successively unbind
28+ def elemwise_fn (* inputs ):
29+ Elemwise ._check_runtime_broadcast (node , inputs )
30+ shaped_inputs = torch .broadcast_tensors (* inputs )
31+ expected_size = shaped_inputs [0 ].numel ()
32+ final_inputs = [s .clone () for s in shaped_inputs ]
33+ for _ in range (shaped_inputs [0 ].dim () - 1 ):
34+ for i , _ in enumerate (shaped_inputs ):
35+ layer = chain .from_iterable ([s .unbind (0 ) for s in final_inputs [i ]])
36+ final_inputs [i ] = list (layer )
37+
38+ # make sure we still have the same number of things
39+ assert len (final_inputs ) == len (shaped_inputs )
40+
41+ # make sure each group of things are the expected size
42+ assert all (len (x ) == expected_size for x in final_inputs )
43+
44+ # make sure they are all single elements
45+ assert all (len (x .shape ) == 0 for tensor in final_inputs for x in tensor )
46+ res = [base_fn (* args ) for args in zip (* final_inputs )]
47+ states = torch .stack (tuple (out [0 ] for out in res ))
48+ done = torch .stack (tuple (out [1 ] for out in res ))
49+ return states , done
50+
2051 else :
2152
2253 def elemwise_fn (* inputs ):
@@ -26,6 +57,7 @@ def elemwise_fn(*inputs):
2657 for _ in range (broadcast_inputs [0 ].dim ()):
2758 ufunc = torch .vmap (ufunc )
2859 return ufunc (* broadcast_inputs )
60+ return base_fn (* inputs )
2961
3062 return elemwise_fn
3163
0 commit comments