@@ -87,15 +87,15 @@ class ModuleContext:
 class BlockInfo:
   full_shape_dtype: jax.ShapeDtypeStruct
   start_indices: Sequence[Any]
-  block_shape: tuple[int, ...]  # TODO(necula): can this contain "mapped"?
+  block_shape: tuple[int | pallas_core.Mapped, ...]
 
 
 @dataclasses.dataclass
 class LoweringRuleContext:
   context: ModuleContext
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]
-  block_infos: Sequence[BlockInfo | None]  # TODO(necula): can this be None?
+  block_infos: Sequence[BlockInfo | None]
 
   replace = dataclasses.replace
 
@@ -362,14 +362,15 @@ def read_env(atom: jax_core.Atom):
   def read_block_info_env(atom: jax_core.Atom):
     if isinstance(atom, jax_core.Literal):
       return None
-    return block_info_env.get(atom, None)
+    return block_info_env.get(atom)
 
   def write_env(var: jax_core.Var, val):
     env[var] = val
 
   if block_infos is not None:
     for invar, block_info in zip(jaxpr.invars, block_infos):
-      block_info_env[invar] = block_info
+      if block_info is not None:
+        block_info_env[invar] = block_info
 
   map(write_env, jaxpr.invars, args)
 
@@ -393,7 +394,7 @@ def write_env(var: jax_core.Var, val):
       raise  # We only add the extra info to the innermost exception.
     except Exception as e:
       if not pallas_call._verbose_errors_enabled():
-        raise
+        raise
       inval_types = map(lambda t: getattr(t, "type", None), invals)
       raise LoweringError(
           f"Exception while lowering eqn:\n  {eqn}\nWith context:\n  "
@@ -474,14 +475,14 @@ def _atomic_lowering_rule(
     args_tree,
     atomic_type: primitives.AtomicOpType,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, val, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) != 1:
     raise NotImplementedError("Only single indexer is supported.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   val = _ensure_ir_value(val, value_aval)
   if mask is not None:
     mask = _ensure_ir_value(mask, mask_aval)
@@ -1631,36 +1632,21 @@ def _reshape_lowering_rule(
 
 
 def _compute_pointers_from_indices(
-    root_ptr: ir.Value,
-    block_info: BlockInfo | None,
-    nd_indexer: NDIndexer,
-    array_shape_dtype: Any,
+    root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
-  if block_info is None:  # TODO(necula): is this branch dead?
-    full_shape = array_shape_dtype.shape
-    num_mapped_dims = 0
-    block_shape = array_shape_dtype.shape
-  else:
-    full_shape = block_info.full_shape_dtype.shape
-    num_mapped_dims = sum(
-        b is pallas_core.mapped for b in block_info.block_shape
-    )
-    block_shape = block_info.block_shape
+  full_shape = block_info.full_shape_dtype.shape
+  num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
   strides = pallas_utils.strides_from_shape(full_shape)
   indexer_shape = nd_indexer.get_indexer_shape()
   int_indexer_shape = nd_indexer.int_indexer_shape
   _check_tensor_size(indexer_shape)
   indices = nd_indexer.indices
   other_shape = indexer_shape[len(int_indexer_shape):]
   other_shape_idx = 0
-  if block_info is None:
-    start_index_offsets = [None] * len(indices)
-  else:
-    start_index_offsets = block_info.start_indices
   assert len(indices) + num_mapped_dims == len(full_shape)
-  assert len(start_index_offsets) == len(full_shape)
+  assert len(block_info.start_indices) == len(full_shape)
 
-  array_dtype = jnp.dtype(array_shape_dtype.dtype)
+  array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype)
   full_size = math.prod(full_shape) * array_dtype.itemsize
   # Use 64-bit indexing when offset might be >= 2**32 bytes.
   offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
@@ -1671,7 +1657,7 @@ def _compute_pointers_from_indices(
 
   indexer_iter = iter(indices)
   for dim_stride, dim_block_size, start_offset in zip(
-      strides, block_shape, start_index_offsets
+      strides, block_info.block_shape, block_info.start_indices
   ):
     if dim_block_size is pallas_core.mapped:
       index = _ir_constant(0, offset_eltype)
@@ -1831,6 +1817,8 @@ def _masked_load_lowering_rule(
     cache_modifier,
     is_volatile,
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, mask, other = args_tree.unflatten(args_flat)
   *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
@@ -1839,9 +1827,7 @@ def _masked_load_lowering_rule(
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(ctx.avals_in) == 1
     return ptr
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   if mask is not None:
     mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
   if other is not None:
@@ -1931,14 +1917,14 @@ def _store(
 def _masked_swap_lowering_rule(
     ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy
 ):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   ptr, indexers, value, mask = args_tree.unflatten(args_flat)
   *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   idx = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, idx)
   other = None
   if value is not None:
     value = _ensure_ir_value(value, value_aval)
@@ -1954,16 +1940,16 @@ def _masked_swap_lowering_rule(
 
 @register_lowering(sp.addupdate_p)
 def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
+  block_info, *_ = ctx.block_infos
+  assert block_info is not None
   indexers = tree_util.tree_unflatten(tree, idx)
   if not tt_dialect.PointerType.isinstance(ptr.type):
     assert len(indexers) == 0
     return ptr
   if len(indexers) > 1:
     raise NotImplementedError("No support for multiple indexers yet.")
   indexer = indexers[0]
-  ptr = _compute_pointers_from_indices(
-      ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
-  )
+  ptr = _compute_pointers_from_indices(ptr, block_info, indexer)
   op = tt_dialect.RMWOp.FADD
   if isinstance(_element_type(value.type), ir.IntegerType):
     op = tt_dialect.RMWOp.ADD
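
Taken together, these hunks make BlockInfo mandatory for pointer computation: each lowering rule now unpacks the leading entry of ctx.block_infos, asserts it is present, and calls the three-argument _compute_pointers_from_indices, whose `block_info is None` fallback and `array_shape_dtype` parameter are gone. A minimal sketch of the shared caller pattern, with _example_lowering_rule as a hypothetical stand-in for the rules touched above:

  def _example_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree):
    # The None fallback inside _compute_pointers_from_indices is removed,
    # so callers must destructure and check the block info up front.
    block_info, *_ = ctx.block_infos
    assert block_info is not None
    ptr, indexers, *_rest = args_tree.unflatten(args_flat)
    if len(indexers) != 1:
      raise NotImplementedError("Only single indexer is supported.")
    # block_info.full_shape_dtype now supplies the shape and dtype that the
    # removed array_shape_dtype argument used to carry.
    return _compute_pointers_from_indices(ptr, block_info, indexers[0])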