2323from jax ._src import config
2424from jax ._src import core
2525from jax ._src import dtypes
26+ from jax ._src import api_util
2627from jax ._src .lax import lax
2728from jax ._src .util import safe_zip , safe_map
2829from jax ._src .typing import Array , ArrayLike , DimSize , DType , DTypeLike , Shape
@@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
213214@partial (api .jit , inline = True )
214215def _broadcast_arrays (* args : ArrayLike ) -> list [Array ]:
215216 """Like Numpy's broadcast_arrays but doesn't return views."""
216- shapes = [np .shape (arg ) for arg in args ]
217+ avals = [api_util .shaped_abstractify (arg ) for arg in args ]
218+ shapes = [a .shape for a in avals ]
217219 if not shapes or all (core .definitely_equal_shape (shapes [0 ], s ) for s in shapes ):
218220 return [lax .asarray (arg ) for arg in args ]
219221 result_shape = lax .broadcast_shapes (* shapes )
220- return [_broadcast_to (arg , result_shape ) for arg in args ]
222+ result_sharding = (lax .broadcast_shardings (* avals ) # type: ignore
223+ if config .sharding_in_types .value else None )
224+ return [_broadcast_to (arg , result_shape , result_sharding ) for arg in args ]
221225
222226
223- def _broadcast_to (arr : ArrayLike , shape : DimSize | Shape ) -> Array :
227+ def _broadcast_to (arr : ArrayLike , shape : DimSize | Shape , sharding = None
228+ ) -> Array :
224229 check_arraylike ("broadcast_to" , arr )
225230 arr = arr if isinstance (arr , Array ) else lax .asarray (arr )
226231 if not isinstance (shape , tuple ) and np .ndim (shape ) == 0 :
@@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
240245 if nlead < 0 or not compatible :
241246 msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
242247 raise ValueError (msg .format (arr_shape , shape ))
243- return lax .broadcast_in_dim (arr , shape , tuple (range (nlead , len (shape ))))
248+ return lax .broadcast_in_dim (arr , shape , tuple (range (nlead , len (shape ))),
249+ sharding = sharding )
244250
245251
246252# The `jit` on `where` exists to avoid materializing constants in cases like
0 commit comments