File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -183,9 +183,9 @@ def __str__(self) -> str:
183183 def from_type (self , ty ):
184184 return pallas_core .MemoryRef (ty , memory_space = self )
185185
186- def __call__ (self , shape : tuple [int , ... ], dtype : jnp .dtype ):
186+ def __call__ (self , shape : Sequence [int ], dtype : jnp .dtype ):
187187 # A convenience function for constructing MemoryRef types of ShapedArrays.
188- return self .from_type (jax_core .ShapedArray (shape , dtype ))
188+ return self .from_type (jax_core .ShapedArray (tuple ( shape ) , dtype ))
189189
190190class dma_semaphore (pallas_core .semaphore_dtype ): pass
191191
Original file line number Diff line number Diff line change @@ -142,14 +142,15 @@ def __str__(self) -> str:
142142
143143 def __call__ (
144144 self ,
145- shape : tuple [int , ... ],
145+ shape : Sequence [int ],
146146 dtype : jnp .dtype ,
147147 * ,
148148 transforms : Sequence [MemoryRefTransform ] = (),
149149 packed : bool | None = None ,
150150 collective : bool | None = None ,
151151 layout : TMEMLayout | None = None ,
152152 ) -> pallas_core .MemoryRef :
153+ shape = tuple (shape )
153154 # TODO(sharadmv): Add HiType constructor support.
154155 if self == MemorySpace .TMEM :
155156 if transforms :
You can’t perform that action at this time.
0 commit comments