 from pytensor.compile import get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.gradient import grad
-from pytensor.scalar import float64
+from pytensor.scalar import Composite, float64
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -548,7 +548,7 @@ def test_Argmax(x, axes, exc):
         )


-def test_elemwise_out_type():
+def test_elemwise_inplace_out_type():
     # Create a graph with an elemwise
     # Ravel fails if the elemwise output type is reported incorrectly
     x = pt.matrix()
@@ -563,6 +563,28 @@ def test_elemwise_out_type():
     assert func(x_val).shape == (18,)


+def test_elemwise_multiple_inplace_outs():
+    x = pt.vector()
+    y = pt.vector()
+
+    x_ = pt.scalar_from_tensor(x[0])
+    y_ = pt.scalar_from_tensor(y[0])
+    out_ = x_ + 1, y_ + 1
+
+    composite_op = Composite([x_, y_], out_)
+    elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})
+    out = elemwise_op(x, y)
+
+    fn = function([x, y], out, mode="NUMBA", accept_inplace=True)
+    x_test = np.array([1, 2, 3], dtype=config.floatX)
+    y_test = np.array([4, 5, 6], dtype=config.floatX)
+    out1, out2 = fn(x_test, y_test)
+    assert out1 is x_test
+    assert out2 is y_test
+    np.testing.assert_allclose(out1, [2, 3, 4])
+    np.testing.assert_allclose(out2, [5, 6, 7])
+
+
 def test_scalar_loop():
     a = float64("a")
     scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
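The new test above exercises Elemwise's inplace_pattern: a dict mapping each output index to the index of the input that output is allowed to overwrite, which is also why the function is compiled with accept_inplace=True. As a rough standalone sketch of the same idea (the variable names, the dvector inputs, and the use of the default mode are illustrative assumptions, not taken from this PR):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.scalar import Composite, float64
from pytensor.tensor.elemwise import Elemwise

# Scalar Composite with two outputs: (a + 1, b + 1)
a = float64("a")
b = float64("b")
comp = Composite([a, b], [a + 1, b + 1])

# Output 0 overwrites input 0, output 1 overwrites input 1
add_one_inplace = Elemwise(comp, inplace_pattern={0: 0, 1: 1})

x = pt.dvector("x")
y = pt.dvector("y")
fn = pytensor.function([x, y], add_one_inplace(x, y), accept_inplace=True)

x_val = np.array([1.0, 2.0, 3.0])
y_val = np.array([4.0, 5.0, 6.0])
out_x, out_y = fn(x_val, y_val)
# Because the op declares a destroy_map via inplace_pattern, the results are
# written into the input buffers: x_val now holds [2., 3., 4.] and y_val [5., 6., 7.].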