6161from jax ._src .interpreters import xla
6262from jax ._src .layout import DeviceLocalLayout , AutoLayout , Layout
6363from jax ._src .lib import xla_client as xc
64+ from jax ._src .lib import xla_extension_version
6465from jax ._src .lib .mlir import ir
6566from jax ._src .lib .mlir .dialects import hlo
6667from jax ._src .partition_spec import PartitionSpec
@@ -105,44 +106,69 @@ class WeakRefList(list):
105106
106107### util
107108
109+
110+ def to_xc_copy_semantics (copy_semantics ):
111+ if xla_extension_version < 296 :
112+ return [None ] * len (copy_semantics )
113+ out = []
114+ for cs in copy_semantics :
115+ if cs is None or cs == dispatch .CopySemantics .ALIAS :
116+ out .append (xc .ArrayCopySemantics .REUSE_INPUT )
117+ elif cs == dispatch .CopySemantics .COPY :
118+ out .append (xc .ArrayCopySemantics .ALWAYS_COPY )
119+ elif cs == dispatch .CopySemantics .DONATE :
120+ out .append (xc .ArrayCopySemantics .DONATE_INPUT )
121+ else :
122+ assert isinstance (cs , xc .ArrayCopySemantics )
123+ out .append (cs )
124+ return out
125+
126+
108127def identity (x ): return x
109128
110129@profiler .annotate_function
111- def shard_args (shardings : Sequence [JSharding ], layouts , args ,
112- canonicalize = True ) -> Sequence [xc .ArrayImpl ]:
130+ def shard_args (shardings : Sequence [JSharding ], layouts , copy_semantics ,
131+ args , canonicalize = True ) -> Sequence [xc .ArrayImpl ]:
132+ xc_copy_semantics = to_xc_copy_semantics (copy_semantics )
133+ del copy_semantics
113134 # Fast path for one argument.
114135 if len (args ) == 1 :
115136 arg = args [0 ]
116137 if canonicalize :
117138 arg = xla .canonicalize_dtype (arg )
118- return shard_arg_handlers [type (arg )]([arg ], shardings , layouts )
119-
120- # type(arg) -> (list[indices], list[args], list[shardings])
121- batches = collections .defaultdict (lambda : ([], [], [], [])) # type: ignore
122- for i , (arg , sharding , layout ) in enumerate (safe_zip (args , shardings , layouts )):
139+ return shard_arg_handlers [type (arg )]([arg ], shardings , layouts ,
140+ xc_copy_semantics )
141+
142+ # type(arg) -> (list[indices], list[args], list[shardings], list[layouts],
143+ # list[copy_semantics])
144+ batches = collections .defaultdict (lambda : ([], [], [], [], [])) # type: ignore
145+ for i , (arg , sharding , layout , cs ) in enumerate (
146+ safe_zip (args , shardings , layouts , xc_copy_semantics )):
123147 if canonicalize :
124148 arg = xla .canonicalize_dtype (arg )
125149 batch = batches [type (arg )]
126150 batch [0 ].append (i )
127151 batch [1 ].append (arg )
128152 batch [2 ].append (sharding )
129153 batch [3 ].append (layout )
154+ batch [4 ].append (cs )
130155
131156 # Call `shard_arg_handlers` per batch and build a flat list of arrays returned
132157 # from each call in the same order as `args`. Since `batches` is grouped by
133158 # types, we cannot simply flatten the results and we have to use the original
134159 # indices to put each array back to its original position.
135160 results : list [jax .Array | None ] = [None ] * len (args )
136- for t , (indices , a , s , l ) in batches .items ():
137- outs = shard_arg_handlers [t ](a , s , l )
161+ for t , (indices , a , s , l , cs ) in batches .items ():
162+ outs = shard_arg_handlers [t ](a , s , l , cs )
138163 for i , out in safe_zip (indices , outs ):
139164 results [i ] = out
140165 assert all (result is not None for result in results )
141166 return results
142167
143168
144169shard_arg_handlers : dict [
145- Any , Callable [[Sequence [Any ], Sequence [Any ], Sequence [Any ]], Sequence [Any ]]
170+ Any , Callable [[Sequence [Any ], Sequence [Any ], Sequence [Any ], Sequence [Any ]],
171+ Sequence [Any ]]
146172] = {}
147173
148174
@@ -172,12 +198,12 @@ def is_default_layout(curr_layout, sharding, aval):
172198 raise
173199
174200
175- def _masked_array_error (xs , shardings , layouts ):
201+ def _masked_array_error (xs , shardings , layouts , copy_semantics ):
176202 raise ValueError ("numpy masked arrays are not supported as direct inputs to JAX functions. "
177203 "Use arr.filled() to convert the value to a standard numpy array." )
178204shard_arg_handlers [np .ma .MaskedArray ] = _masked_array_error
179205
180- def _shard_np_array (xs , shardings , layouts ):
206+ def _shard_np_array (xs , shardings , layouts , copy_semantics ):
181207 results = []
182208 for x , sharding , layout in safe_zip (xs , shardings , layouts ):
183209 devices = sharding ._addressable_device_assignment
@@ -197,12 +223,12 @@ def _shard_np_array(xs, shardings, layouts):
197223for _t in array_types :
198224 shard_arg_handlers [_t ] = _shard_np_array
199225
200- def _shard_darray (xs , shardings , layouts ):
201- return shard_args (shardings , layouts , [x ._data for x in xs ])
226+ def _shard_darray (xs , shardings , layouts , copy_semantics ):
227+ return shard_args (shardings , layouts , copy_semantics , [x ._data for x in xs ])
202228shard_arg_handlers [core .DArray ] = _shard_darray
203229
204- def _shard_mutable_array (xs , shardings , layouts ):
205- return shard_args (shardings , layouts , [x ._buf for x in xs ])
230+ def _shard_mutable_array (xs , shardings , layouts , copy_semantics ):
231+ return shard_args (shardings , layouts , copy_semantics , [x ._buf for x in xs ])
206232shard_arg_handlers [core .MutableArray ] = _shard_mutable_array
207233
208234def batched_device_put (aval : core .ShapedArray ,
@@ -1135,7 +1161,8 @@ class InputsHandler:
11351161
11361162 def __init__ (self , in_shardings , in_layouts , local_devices = None ,
11371163 input_indices = None ):
1138- self .handler = partial (shard_args , in_shardings , in_layouts )
1164+ self .handler = partial (shard_args , in_shardings , in_layouts ,
1165+ [None ] * len (in_shardings ))
11391166 self .in_shardings = in_shardings
11401167 self .in_layouts = in_layouts
11411168 self .local_devices = local_devices
@@ -3047,7 +3074,7 @@ def aot_cache_miss(*args, **kwargs):
30473074 JitGlobalCppCacheKeys (), tree_util .dispatch_registry , cc_shard_arg )
30483075
30493076def cc_shard_arg (x , sharding , layout ):
3050- return shard_args ([sharding ], [layout ], [x ])[0 ]
3077+ return shard_args ([sharding ], [layout ], [None ], [ x ])[0 ]
30513078
30523079
30533080def check_arg_avals_for_call (ref_avals , arg_avals ,
0 commit comments