11import importlib
2+ from itertools import chain
23
34import torch
45
56from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
7+ from pytensor .scalar import ScalarLoop
68from pytensor .tensor .elemwise import DimShuffle , Elemwise
79from pytensor .tensor .math import All , Any , Max , Min , Prod , Sum
810from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
@@ -33,6 +35,33 @@ def elemwise_fn(*inputs):
3335 Elemwise ._check_runtime_broadcast (node , inputs )
3436 return base_fn (* inputs )
3537
38+ elif isinstance (scalar_op , ScalarLoop ):
39+ # note: scalarloop + elemwise is too common
40+ # to not work, but @1031, vmap won't allow it.
41+ # Instead, we will just successively unbind
42+ def elemwise_fn (* inputs ):
43+ Elemwise ._check_runtime_broadcast (node , inputs )
44+ shaped_inputs = torch .broadcast_tensors (* inputs )
45+ expected_size = shaped_inputs [0 ].numel ()
46+ final_inputs = [s .clone () for s in shaped_inputs ]
47+ for _ in range (shaped_inputs [0 ].dim () - 1 ):
48+ for i , _ in enumerate (shaped_inputs ):
49+ layer = chain .from_iterable ([s .unbind (0 ) for s in final_inputs [i ]])
50+ final_inputs [i ] = list (layer )
51+
52+ # make sure we still have the same number of things
53+ assert len (final_inputs ) == len (shaped_inputs )
54+
55+ # make sure each group of things are the expected size
56+ assert all (len (x ) == expected_size for x in final_inputs )
57+
58+ # make sure they are all single elements
59+ assert all (len (x .shape ) == 0 for tensor in final_inputs for x in tensor )
60+ res = [base_fn (* args ) for args in zip (* final_inputs )]
61+ states = torch .stack (tuple (out [0 ] for out in res ))
62+ done = torch .stack (tuple (out [1 ] for out in res ))
63+ return states , done
64+
3665 else :
3766
3867 def elemwise_fn (* inputs ):
@@ -42,6 +71,7 @@ def elemwise_fn(*inputs):
4271 for _ in range (broadcast_inputs [0 ].dim ()):
4372 ufunc = torch .vmap (ufunc )
4473 return ufunc (* broadcast_inputs )
74+ return base_fn (* inputs )
4575
4676 return elemwise_fn
4777
0 commit comments