6464 xla_client ._xla .mlir .xla_computation_to_mlir_module )
6565
6666
67+ def execute_with_python_values (executable , arguments , backend ): # pylint: disable=invalid-name
68+ """Execute on one replica with Python values as arguments and output."""
69+
70+ def put (arg ): # pylint: disable=invalid-name
71+ return backend .buffer_from_pyval (arg , device = executable .local_devices ()[0 ])
72+
73+ arguments = [put (arg ) for arg in arguments ]
74+ outputs = executable .execute (arguments )
75+ return [np .asarray (x ) for x in outputs ]
76+
77+
6778# pylint: disable=invalid-name
6879def jax_array_convert_to_array (self , dtype = None , copy = None ):
6980 del copy
@@ -164,7 +175,7 @@ def _NewComputation(self, name=None):
164175 def _Execute (self , c , arguments ):
165176 compiled_c = self .backend .compile (
166177 xla_computation_to_mlir_module (c .build ()))
167- return xla_client . execute_with_python_values (
178+ return execute_with_python_values (
168179 compiled_c , arguments , backend = self .backend )
169180
170181 def _ExecuteAndAssertWith (self , assert_func , c , arguments , expected ):
@@ -596,7 +607,7 @@ def testExecuteFromProto(self):
596607 # Load and execute the proto
597608 c = xla_client .XlaComputation (serialized_proto )
598609 m = xla_computation_to_mlir_module (c )
599- ans , = xla_client . execute_with_python_values (
610+ ans , = execute_with_python_values (
600611 self .backend .compile (m ), (), backend = self .backend )
601612 np .testing .assert_equal (ans , np .int32 (3 ))
602613
@@ -1245,7 +1256,7 @@ def testConvertElementType(self, src_dtype, dst_dtype):
12451256 ops .ConvertElementType (
12461257 ops .Constant (c , x ), xla_client .dtype_to_etype (dst_dtype ))
12471258
1248- result = xla_client . execute_with_python_values (
1259+ result = execute_with_python_values (
12491260 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
12501261 backend = self .backend )
12511262 self .assertLen (result , 1 )
@@ -1275,7 +1286,7 @@ def testBitcastConvertType(self, src_dtype, dst_dtype):
12751286 ops .BitcastConvertType (
12761287 ops .Constant (c , x ), xla_client .dtype_to_etype (dst_dtype ))
12771288
1278- result = xla_client . execute_with_python_values (
1289+ result = execute_with_python_values (
12791290 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
12801291 backend = self .backend )
12811292 self .assertLen (result , 1 )
@@ -1859,7 +1870,7 @@ def testTuple(self):
18591870 ops .Constant (c , NumpyArrayF32 ([1.0 , 2.0 ])),
18601871 ops .Constant (c , NumpyArrayBool ([True , False , False , True ]))
18611872 ])
1862- result = xla_client . execute_with_python_values (
1873+ result = execute_with_python_values (
18631874 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
18641875 backend = self .backend )
18651876 self .assertLen (result , 3 )
@@ -1899,7 +1910,7 @@ def testRngNormal(self):
18991910 ops .Constant (c , NumpyArrayF32 (1. )),
19001911 shape = xla_client .Shape .array_shape (xla_client .PrimitiveType .F32 ,
19011912 shape ))
1902- result = xla_client . execute_with_python_values (
1913+ result = execute_with_python_values (
19031914 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
19041915 backend = self .backend )
19051916 # since the result is random, we just check shape and uniqueness
@@ -1916,7 +1927,7 @@ def testRngUniformF32(self):
19161927 ops .Constant (c , NumpyArrayF32 (hi )),
19171928 shape = xla_client .Shape .array_shape (xla_client .PrimitiveType .F32 ,
19181929 shape ))
1919- result = xla_client . execute_with_python_values (
1930+ result = execute_with_python_values (
19201931 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
19211932 backend = self .backend )
19221933 # since the result is random, we just check shape, uniqueness, and range
@@ -1935,7 +1946,7 @@ def testRngUniformS32(self):
19351946 ops .Constant (c , NumpyArrayS32 (hi )),
19361947 shape = xla_client .Shape .array_shape (xla_client .PrimitiveType .S32 ,
19371948 shape ))
1938- result = xla_client . execute_with_python_values (
1949+ result = execute_with_python_values (
19391950 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
19401951 backend = self .backend )
19411952 # since the result is random, we just check shape, integrality, and range
@@ -1965,7 +1976,7 @@ def testSortKeyVal(self):
19651976 values = np .array ([[0 , 1 , 2 , 3 ], [4 , 5 , 6 , 7 ]], dtype = np .int32 )
19661977 c = self ._NewComputation ()
19671978 ops .Sort (c , (ops .Constant (c , keys ), ops .Constant (c , values )), dimension = 0 )
1968- result = xla_client . execute_with_python_values (
1979+ result = execute_with_python_values (
19691980 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
19701981 backend = self .backend )
19711982 self .assertLen (result , 2 )
@@ -1988,7 +1999,7 @@ def testSortCustomComparator(self):
19881999 c , (ops .Constant (c , keys ), ops .Constant (c , values )),
19892000 dimension = 1 ,
19902001 comparator = comparator )
1991- result = xla_client . execute_with_python_values (
2002+ result = execute_with_python_values (
19922003 self .backend .compile (xla_computation_to_mlir_module (c .build ())), (),
19932004 backend = self .backend )
19942005 self .assertLen (result , 2 )
@@ -2578,7 +2589,7 @@ def testInfeedS32Values(self):
25782589 device .transfer_to_infeed (item )
25792590
25802591 for item in to_infeed :
2581- result , = xla_client . execute_with_python_values (
2592+ result , = execute_with_python_values (
25822593 compiled_c , (), backend = self .backend )
25832594 self .assertEqual (result , item )
25842595
@@ -2597,7 +2608,7 @@ def testInfeedTuple(self):
25972608 device = self .backend .local_devices ()[0 ]
25982609 device .transfer_to_infeed (to_infeed )
25992610
2600- result = xla_client . execute_with_python_values (
2611+ result = execute_with_python_values (
26012612 compiled_c , (), backend = self .backend )
26022613 self .assertLen (result , 2 )
26032614 np .testing .assert_equal (result [0 ], to_infeed [0 ])
@@ -2741,7 +2752,7 @@ def testInvokeWithWrongElementType(self):
27412752 c .clear_op_metadata ()
27422753
27432754 def TestFun ():
2744- return xla_client . execute_with_python_values (
2755+ return execute_with_python_values (
27452756 self .backend .compile (xla_computation_to_mlir_module (c .build ())),
27462757 [self .f32_scalar_2 ], self .backend )
27472758
@@ -2763,7 +2774,7 @@ def testComputationRootDifferentFromLastOp(self):
27632774 arg = NumpyArrayF32 (1.0 )
27642775 compiled_c = self .backend .compile (
27652776 xla_computation_to_mlir_module (c .build (result )))
2766- ans , = xla_client . execute_with_python_values (
2777+ ans , = execute_with_python_values (
27672778 compiled_c , [arg ], backend = self .backend )
27682779 np .testing .assert_allclose (ans , 4.14 )
27692780
@@ -2787,7 +2798,7 @@ def testSetSharding(self):
27872798 arg = NumpyArrayF32 (1.0 )
27882799 compiled_c = self .backend .compile (
27892800 xla_computation_to_mlir_module (c .build (result )))
2790- ans , = xla_client . execute_with_python_values (
2801+ ans , = execute_with_python_values (
27912802 compiled_c , [arg ], backend = self .backend )
27922803 np .testing .assert_allclose (ans , 4.14 )
27932804
@@ -3128,7 +3139,7 @@ def testHloProgramViaIfrtProgram(self):
31283139 )
31293140
31303141 compiled_c = self .backend .compile_ifrt_program (program , options )
3131- results = xla_client . execute_with_python_values (
3142+ results = execute_with_python_values (
31323143 compiled_c , arguments = (), backend = self .backend
31333144 )
31343145
@@ -3154,10 +3165,8 @@ def testExecutableSerialization(self):
31543165 serialized = self .backend .serialize_executable (executable )
31553166 deserialized = self .backend .deserialize_executable (serialized , options )
31563167
3157- expected , = xla_client .execute_with_python_values (executable , (),
3158- self .backend )
3159- actual , = xla_client .execute_with_python_values (deserialized , (),
3160- self .backend )
3168+ expected , = execute_with_python_values (executable , (), self .backend )
3169+ actual , = execute_with_python_values (deserialized , (), self .backend )
31613170 self .assertTrue (np .all (actual == expected ))
31623171
31633172 def testCompileOptionsSerialization (self ):
0 commit comments