Skip to content

Commit 8cb57d6

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Delete xla_client.execute_with_python_values.
This is not a public API and exists only for testing. PiperOrigin-RevId: 681453343
1 parent 9dc8ec1 commit 8cb57d6

File tree

3 files changed

+29
-61
lines changed

3 files changed

+29
-61
lines changed

xla/python/xla_client.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -469,40 +469,6 @@ def computation_count():
469469
# There are different implementations of Executable for different backends.
470470

471471

472-
def execute_with_python_values(executable, arguments, backend):
473-
"""Execute on one replica with Python values as arguments and output."""
474-
475-
def put(arg):
476-
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
477-
478-
arguments = [put(arg) for arg in arguments]
479-
outputs = executable.execute(arguments)
480-
return [np.asarray(x) for x in outputs]
481-
482-
483-
def execute_with_python_values_replicated(executable, arguments, backend):
484-
"""Execute on many replicas with Python values as arguments and output.
485-
486-
Args:
487-
executable: the program to run.
488-
arguments: a list of lists of Python values indexed by `[replica][arg_num]`
489-
to pass as inputs.
490-
backend: the backend we are targeting.
491-
492-
Returns:
493-
A list of python values, one per replica.
494-
"""
495-
devices = executable.local_devices()
496-
497-
# pylint: disable=g-complex-comprehension
498-
def copy_to_devices(pyvals):
499-
return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]
500-
501-
inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
502-
outputs = executable.execute_sharded_on_local_devices(inputs)
503-
return [[np.asarray(x) for x in xs] for xs in zip(*outputs)]
504-
505-
506472
class PaddingType(enum.Enum):
507473
VALID = 1
508474
SAME = 2

xla/python/xla_client.pyi

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,6 @@ _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
7171
def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType:
7272
...
7373

74-
def execute_with_python_values(executable: LoadedExecutable, arguments: Sequence[Any],
75-
backend: Client) -> Sequence[numpy.ndarray]: ...
76-
77-
def execute_with_python_values_replicated(
78-
executable: LoadedExecutable, arguments: Sequence[Sequence[Any]],
79-
backend: Client) -> Sequence[Sequence[numpy.ndarray]]: ...
80-
8174
def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ...
8275

8376
def heap_profile(client: Client) -> bytes:

xla/python/xla_client_test.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@
6464
xla_client._xla.mlir.xla_computation_to_mlir_module)
6565

6666

67+
def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name
68+
"""Execute on one replica with Python values as arguments and output."""
69+
70+
def put(arg): # pylint: disable=invalid-name
71+
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
72+
73+
arguments = [put(arg) for arg in arguments]
74+
outputs = executable.execute(arguments)
75+
return [np.asarray(x) for x in outputs]
76+
77+
6778
# pylint: disable=invalid-name
6879
def jax_array_convert_to_array(self, dtype=None, copy=None):
6980
del copy
@@ -164,7 +175,7 @@ def _NewComputation(self, name=None):
164175
def _Execute(self, c, arguments):
165176
compiled_c = self.backend.compile(
166177
xla_computation_to_mlir_module(c.build()))
167-
return xla_client.execute_with_python_values(
178+
return execute_with_python_values(
168179
compiled_c, arguments, backend=self.backend)
169180

170181
def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
@@ -596,7 +607,7 @@ def testExecuteFromProto(self):
596607
# Load and execute the proto
597608
c = xla_client.XlaComputation(serialized_proto)
598609
m = xla_computation_to_mlir_module(c)
599-
ans, = xla_client.execute_with_python_values(
610+
ans, = execute_with_python_values(
600611
self.backend.compile(m), (), backend=self.backend)
601612
np.testing.assert_equal(ans, np.int32(3))
602613

@@ -1245,7 +1256,7 @@ def testConvertElementType(self, src_dtype, dst_dtype):
12451256
ops.ConvertElementType(
12461257
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
12471258

1248-
result = xla_client.execute_with_python_values(
1259+
result = execute_with_python_values(
12491260
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
12501261
backend=self.backend)
12511262
self.assertLen(result, 1)
@@ -1275,7 +1286,7 @@ def testBitcastConvertType(self, src_dtype, dst_dtype):
12751286
ops.BitcastConvertType(
12761287
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
12771288

1278-
result = xla_client.execute_with_python_values(
1289+
result = execute_with_python_values(
12791290
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
12801291
backend=self.backend)
12811292
self.assertLen(result, 1)
@@ -1859,7 +1870,7 @@ def testTuple(self):
18591870
ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
18601871
ops.Constant(c, NumpyArrayBool([True, False, False, True]))
18611872
])
1862-
result = xla_client.execute_with_python_values(
1873+
result = execute_with_python_values(
18631874
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
18641875
backend=self.backend)
18651876
self.assertLen(result, 3)
@@ -1899,7 +1910,7 @@ def testRngNormal(self):
18991910
ops.Constant(c, NumpyArrayF32(1.)),
19001911
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
19011912
shape))
1902-
result = xla_client.execute_with_python_values(
1913+
result = execute_with_python_values(
19031914
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
19041915
backend=self.backend)
19051916
# since the result is random, we just check shape and uniqueness
@@ -1916,7 +1927,7 @@ def testRngUniformF32(self):
19161927
ops.Constant(c, NumpyArrayF32(hi)),
19171928
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
19181929
shape))
1919-
result = xla_client.execute_with_python_values(
1930+
result = execute_with_python_values(
19201931
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
19211932
backend=self.backend)
19221933
# since the result is random, we just check shape, uniqueness, and range
@@ -1935,7 +1946,7 @@ def testRngUniformS32(self):
19351946
ops.Constant(c, NumpyArrayS32(hi)),
19361947
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
19371948
shape))
1938-
result = xla_client.execute_with_python_values(
1949+
result = execute_with_python_values(
19391950
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
19401951
backend=self.backend)
19411952
# since the result is random, we just check shape, integrality, and range
@@ -1965,7 +1976,7 @@ def testSortKeyVal(self):
19651976
values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
19661977
c = self._NewComputation()
19671978
ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
1968-
result = xla_client.execute_with_python_values(
1979+
result = execute_with_python_values(
19691980
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
19701981
backend=self.backend)
19711982
self.assertLen(result, 2)
@@ -1988,7 +1999,7 @@ def testSortCustomComparator(self):
19881999
c, (ops.Constant(c, keys), ops.Constant(c, values)),
19892000
dimension=1,
19902001
comparator=comparator)
1991-
result = xla_client.execute_with_python_values(
2002+
result = execute_with_python_values(
19922003
self.backend.compile(xla_computation_to_mlir_module(c.build())), (),
19932004
backend=self.backend)
19942005
self.assertLen(result, 2)
@@ -2578,7 +2589,7 @@ def testInfeedS32Values(self):
25782589
device.transfer_to_infeed(item)
25792590

25802591
for item in to_infeed:
2581-
result, = xla_client.execute_with_python_values(
2592+
result, = execute_with_python_values(
25822593
compiled_c, (), backend=self.backend)
25832594
self.assertEqual(result, item)
25842595

@@ -2597,7 +2608,7 @@ def testInfeedTuple(self):
25972608
device = self.backend.local_devices()[0]
25982609
device.transfer_to_infeed(to_infeed)
25992610

2600-
result = xla_client.execute_with_python_values(
2611+
result = execute_with_python_values(
26012612
compiled_c, (), backend=self.backend)
26022613
self.assertLen(result, 2)
26032614
np.testing.assert_equal(result[0], to_infeed[0])
@@ -2741,7 +2752,7 @@ def testInvokeWithWrongElementType(self):
27412752
c.clear_op_metadata()
27422753

27432754
def TestFun():
2744-
return xla_client.execute_with_python_values(
2755+
return execute_with_python_values(
27452756
self.backend.compile(xla_computation_to_mlir_module(c.build())),
27462757
[self.f32_scalar_2], self.backend)
27472758

@@ -2763,7 +2774,7 @@ def testComputationRootDifferentFromLastOp(self):
27632774
arg = NumpyArrayF32(1.0)
27642775
compiled_c = self.backend.compile(
27652776
xla_computation_to_mlir_module(c.build(result)))
2766-
ans, = xla_client.execute_with_python_values(
2777+
ans, = execute_with_python_values(
27672778
compiled_c, [arg], backend=self.backend)
27682779
np.testing.assert_allclose(ans, 4.14)
27692780

@@ -2787,7 +2798,7 @@ def testSetSharding(self):
27872798
arg = NumpyArrayF32(1.0)
27882799
compiled_c = self.backend.compile(
27892800
xla_computation_to_mlir_module(c.build(result)))
2790-
ans, = xla_client.execute_with_python_values(
2801+
ans, = execute_with_python_values(
27912802
compiled_c, [arg], backend=self.backend)
27922803
np.testing.assert_allclose(ans, 4.14)
27932804

@@ -3128,7 +3139,7 @@ def testHloProgramViaIfrtProgram(self):
31283139
)
31293140

31303141
compiled_c = self.backend.compile_ifrt_program(program, options)
3131-
results = xla_client.execute_with_python_values(
3142+
results = execute_with_python_values(
31323143
compiled_c, arguments=(), backend=self.backend
31333144
)
31343145

@@ -3154,10 +3165,8 @@ def testExecutableSerialization(self):
31543165
serialized = self.backend.serialize_executable(executable)
31553166
deserialized = self.backend.deserialize_executable(serialized, options)
31563167

3157-
expected, = xla_client.execute_with_python_values(executable, (),
3158-
self.backend)
3159-
actual, = xla_client.execute_with_python_values(deserialized, (),
3160-
self.backend)
3168+
expected, = execute_with_python_values(executable, (), self.backend)
3169+
actual, = execute_with_python_values(deserialized, (), self.backend)
31613170
self.assertTrue(np.all(actual == expected))
31623171

31633172
def testCompileOptionsSerialization(self):

0 commit comments

Comments
 (0)