Skip to content

Commit 8b2f4e7

Browse files
Automated Code Change
PiperOrigin-RevId: 872765836
1 parent 757c92d commit 8b2f4e7

File tree

11 files changed

+118
-111
lines changed

11 files changed

+118
-111
lines changed

docs/pallas/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
2525
For example, to force the use of the Triton backend you have to now write
2626
`compiler_params=pltriton.CompilerParams()`, where `pltriton` refers to
2727
{mod}`jax.experimental.pallas.triton`.
28+
* Renamed {class}`jax.experimental.pallas.tpu.KernelType` to `CoreType`. The
29+
old name is deprecated and will be removed in a future release.
2830

2931
## Released with JAX 0.9.0
3032

jax/_src/pallas/mosaic/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
_out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping
4444

4545

46-
class KernelType(enum.Enum):
46+
class CoreType(enum.Enum):
4747
TC = 0
4848
SC_SCALAR_SUBCORE = 1
4949
SC_VECTOR_SUBCORE = 2
@@ -114,7 +114,7 @@ class CompilerParams(pallas_core.CompilerParams):
114114
flags: dict[str, Any] | None = None
115115
internal_scratch_in_bytes: int | None = None
116116
serialization_format: int = 1
117-
kernel_type: KernelType = KernelType.TC
117+
kernel_type: CoreType = CoreType.TC
118118
disable_bounds_checks: bool = False
119119
skip_device_barrier: bool = False
120120
allow_collective_id_without_custom_barrier: bool = False
@@ -131,7 +131,7 @@ def __init__(
131131
flags: Mapping[str, Any] | None = None,
132132
internal_scratch_in_bytes: int | None = None,
133133
serialization_format: int = 1,
134-
kernel_type: KernelType = KernelType.TC,
134+
kernel_type: CoreType = CoreType.TC,
135135
disable_bounds_checks: bool = False,
136136
skip_device_barrier: bool = False,
137137
allow_collective_id_without_custom_barrier: bool = False,
@@ -190,7 +190,7 @@ def __str__(self) -> str:
190190
def from_type(self, ty):
191191
return pallas_core.MemoryRef(ty, memory_space=self)
192192

193-
def __call__(self, shape: Sequence[int], dtype: jnp.dtype):
193+
def __call__(self, shape: Sequence[int], dtype: jnp.dtype[Any]):
194194
# A convenience function for constructing MemoryRef types of ShapedArrays.
195195
return self.from_type(jax_core.ShapedArray(tuple(shape), dtype))
196196

@@ -283,8 +283,8 @@ def __hash__(self) -> int:
283283
)
284284

285285
@property
286-
def kernel_type(self) -> KernelType:
287-
return KernelType.TC
286+
def kernel_type(self) -> CoreType:
287+
return CoreType.TC
288288

289289
@property
290290
def default_memory_space(self) -> pallas_core.MemorySpace:

0 commit comments

Comments
 (0)