 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

     def check_special_scipy(func_name):
@@ -36,31 +37,7 @@ def elemwise_fn(*inputs):
             return base_fn(*inputs)

     elif isinstance(scalar_op, ScalarLoop):
-        # note: scalarloop + elemwise is too common
-        # to not work, but @1031, vmap won't allow it.
-        # Instead, we will just successively unbind
-        def elemwise_fn(*inputs):
-            Elemwise._check_runtime_broadcast(node, inputs)
-            shaped_inputs = torch.broadcast_tensors(*inputs)
-            expected_size = shaped_inputs[0].numel()
-            final_inputs = [s.clone() for s in shaped_inputs]
-            for _ in range(shaped_inputs[0].dim() - 1):
-                for i, _ in enumerate(shaped_inputs):
-                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                    final_inputs[i] = list(layer)
-
-            # make sure we still have the same number of things
-            assert len(final_inputs) == len(shaped_inputs)
-
-            # make sure each group of things are the expected size
-            assert all(len(x) == expected_size for x in final_inputs)
-
-            # make sure they are all single elements
-            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-            res = [base_fn(*args) for args in zip(*final_inputs)]
-            states = torch.stack(tuple(out[0] for out in res))
-            done = torch.stack(tuple(out[1] for out in res))
-            return states, done
+        return elemwise_scalar_loop(base_fn, op, node, **kwargs)

     else:

@@ -206,3 +183,56 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

     return softmax_grad
+
+
+def elemwise_scalar_loop(base_fn, op, node, **kwargs):
189+ """
190+ ScalarLoop + Elemwise is too common
191+ to not work, but @1031, vmap won't allow it.
192+ Instead, we can do the following strategy
193+ 1. `.unbind(dim)` will return a list of tensors
194+ representing `dim` but "unwrapped". e.x.
195+ ```
196+ t = torch.ones(3, 4, 2)
197+ len(t.unbind(0)) == 3
198+ t[0].shape == torch.Size[4, 2]
199+ 2. If we successfully apply, the length of the list will grow
200+ by the next dimension in the tensor if we flatten the previous
201+ dimension result
202+ ```
203+ inputs = [torch.ones(3, 4, 2)]
204+ level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
205+ level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
206+ len(level_2) == 3 * 4
207+ ```
208+ 3. Eventually we'll reach single dimension tensors. At that point
209+ we can iterate over each input in an element by element manner
210+ and call some function
211+
212+ For scalar loop, we need to broadcast the tensors so all
213+ the necessary values are repeated, and we "evenly" iterate through everything
214+ """
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        shaped_inputs = torch.broadcast_tensors(*inputs)
+        expected_size = shaped_inputs[0].numel()
+        final_inputs = [s.clone() for s in shaped_inputs]
+        # successively unbind the leading dimension until only 0-d tensors remain
+        for _ in range(shaped_inputs[0].dim() - 1):
+            for i, _ in enumerate(shaped_inputs):
+                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                final_inputs[i] = list(layer)
+
+        # make sure we still have the same number of inputs
+        assert len(final_inputs) == len(shaped_inputs)
+
+        # make sure each flattened input has the expected number of elements
+        assert all(len(x) == expected_size for x in final_inputs)
+
+        # make sure every element is a 0-d tensor
+        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+
+        res = [base_fn(*args) for args in zip(*final_inputs)]
+
+        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+
+    return elemwise_fn
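
A minimal standalone sketch of the flattening strategy this helper relies on, assuming a toy two-output scalar function in place of a compiled ScalarLoop (`toy_base_fn`, `a`, `b`, and all other names below are illustrative, not part of this change):

```
from itertools import chain

import torch


def toy_base_fn(x, y):
    # hypothetical stand-in for a compiled ScalarLoop body: returns two outputs
    return x + y, x > y


a = torch.arange(6.0).reshape(3, 2)
b = torch.tensor([10.0, 20.0])  # broadcasts against `a`

shaped = torch.broadcast_tensors(a, b)
final = [s.clone() for s in shaped]
for _ in range(shaped[0].dim() - 1):
    for i, _ in enumerate(shaped):
        final[i] = list(chain.from_iterable(s.unbind(0) for s in final[i]))

# every element is now a 0-d tensor, so the scalar function applies pointwise
res = [toy_base_fn(*args) for args in zip(*final)]
stacked = [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]

assert torch.equal(stacked[0], (a + b).reshape(-1))
```

This mirrors what `elemwise_fn` returns above: one stacked 1-D tensor per output of the scalar function.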