Skip to content

Commit c1a60c6

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas] Add empty/empty_like helper functions
PiperOrigin-RevId: 713344151
1 parent 5511949 commit c1a60c6

File tree

4 files changed

+55
-1
lines changed

4 files changed

+55
-1
lines changed

jax/_src/pallas/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ py_library(
3131
"__init__.py",
3232
"core.py",
3333
"cost_estimate.py",
34+
"helpers.py",
3435
"pallas_call.py",
3536
"primitives.py",
3637
"utils.py",

jax/_src/pallas/helpers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Pallas helper functions."""
15+
16+
from typing import Any, Protocol
17+
18+
import jax
19+
import jax.numpy as jnp
20+
from jax._src.pallas import pallas_call
21+
from jax._src.pallas import core as pl_core
22+
23+
24+
@jax.named_call
25+
def empty(
26+
shape: tuple[int, ...], dtype: jnp.dtype, *, memory_space: Any = None
27+
):
28+
def _empty_kernel(_):
29+
# No-op to leave the out_ref uninitialized
30+
pass
31+
32+
if memory_space is None:
33+
kernel_memory_space = pl_core.MemorySpace.ANY
34+
memory_space = jax.ShapeDtypeStruct
35+
else:
36+
kernel_memory_space = memory_space
37+
return pallas_call.pallas_call(
38+
_empty_kernel,
39+
in_specs=[],
40+
out_specs=pl_core.BlockSpec(memory_space=kernel_memory_space),
41+
out_shape=memory_space(shape, dtype),
42+
)()
43+
44+
45+
class ArrayLike(Protocol):
46+
shape: tuple[int, ...]
47+
dtype: jnp.dtype
48+
49+
50+
def empty_like(x: ArrayLike, *, memory_space: Any = None):
51+
return empty(x.shape, x.dtype, memory_space=memory_space)

jax/_src/pallas/mosaic/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def _tensorcore_mesh_discharge_rule(
242242
*args,
243243
mesh,
244244
jaxpr,
245-
compiler_params: TPUCompilerParams,
245+
compiler_params: Any | None,
246246
interpret: bool,
247247
debug: bool,
248248
cost_estimate: pallas_core.CostEstimate | None,

jax/experimental/pallas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from jax._src.pallas.core import Unblocked as Unblocked
3232
from jax._src.pallas.core import unblocked as unblocked
3333
from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost
34+
from jax._src.pallas.helpers import empty as empty
35+
from jax._src.pallas.helpers import empty_like as empty_like
3436
from jax._src.pallas.pallas_call import pallas_call as pallas_call
3537
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
3638
from jax._src.pallas.primitives import atomic_add as atomic_add

0 commit comments

Comments
 (0)