Skip to content

Commit 93c5748

Browse files
brianwa84Google-ML-Automation
authored andcommitted
Tiny usability tweak to allow user to pass a list instead of a tuple, e.g. pltpu.VMEM([8], jnp.int32).
PiperOrigin-RevId: 836134234
1 parent 6188dd2 commit 93c5748

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

jax/_src/pallas/mosaic/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def __str__(self) -> str:
183183
def from_type(self, ty):
184184
return pallas_core.MemoryRef(ty, memory_space=self)
185185

186-
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
186+
def __call__(self, shape: Sequence[int], dtype: jnp.dtype):
187187
# A convenience function for constructing MemoryRef types of ShapedArrays.
188-
return self.from_type(jax_core.ShapedArray(shape, dtype))
188+
return self.from_type(jax_core.ShapedArray(tuple(shape), dtype))
189189

190190
class dma_semaphore(pallas_core.semaphore_dtype): pass
191191

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,15 @@ def __str__(self) -> str:
142142

143143
def __call__(
144144
self,
145-
shape: tuple[int, ...],
145+
shape: Sequence[int],
146146
dtype: jnp.dtype,
147147
*,
148148
transforms: Sequence[MemoryRefTransform] = (),
149149
packed: bool | None = None,
150150
collective: bool | None = None,
151151
layout: TMEMLayout | None = None,
152152
) -> pallas_core.MemoryRef:
153+
shape = tuple(shape)
153154
# TODO(sharadmv): Add HiType constructor support.
154155
if self == MemorySpace.TMEM:
155156
if transforms:

0 commit comments

Comments
 (0)