4343_out_shape_to_aval_mapping = pallas_core ._out_shape_to_aval_mapping
4444
4545
46- class KernelType (enum .Enum ):
46+ class CoreType (enum .Enum ):
4747 TC = 0
4848 SC_SCALAR_SUBCORE = 1
4949 SC_VECTOR_SUBCORE = 2
@@ -114,7 +114,7 @@ class CompilerParams(pallas_core.CompilerParams):
114114 flags : dict [str , Any ] | None = None
115115 internal_scratch_in_bytes : int | None = None
116116 serialization_format : int = 1
117- kernel_type : KernelType = KernelType .TC
117+ kernel_type : CoreType = CoreType .TC
118118 disable_bounds_checks : bool = False
119119 skip_device_barrier : bool = False
120120 allow_collective_id_without_custom_barrier : bool = False
@@ -131,7 +131,7 @@ def __init__(
131131 flags : Mapping [str , Any ] | None = None ,
132132 internal_scratch_in_bytes : int | None = None ,
133133 serialization_format : int = 1 ,
134- kernel_type : KernelType = KernelType .TC ,
134+ kernel_type : CoreType = CoreType .TC ,
135135 disable_bounds_checks : bool = False ,
136136 skip_device_barrier : bool = False ,
137137 allow_collective_id_without_custom_barrier : bool = False ,
@@ -190,7 +190,7 @@ def __str__(self) -> str:
190190 def from_type (self , ty ):
191191 return pallas_core .MemoryRef (ty , memory_space = self )
192192
193- def __call__ (self , shape : Sequence [int ], dtype : jnp .dtype ):
193+ def __call__ (self , shape : Sequence [int ], dtype : jnp .dtype [ Any ] ):
194194 # A convenience function for constructing MemoryRef types of ShapedArrays.
195195 return self .from_type (jax_core .ShapedArray (tuple (shape ), dtype ))
196196
@@ -283,8 +283,8 @@ def __hash__(self) -> int:
283283 )
284284
285285 @property
286- def kernel_type (self ) -> KernelType :
287- return KernelType .TC
286+ def kernel_type (self ) -> CoreType :
287+ return CoreType .TC
288288
289289 @property
290290 def default_memory_space (self ) -> pallas_core .MemorySpace :
0 commit comments