@@ -1652,8 +1652,8 @@ def _reshape_lowering_rule(
16521652 )
16531653
16541654
1655- def _compute_offsets_from_indices (
1656- block_info : BlockInfo , nd_indexer : NDIndexer
1655+ def _compute_pointers_from_indices (
1656+ root_ptr : ir . Value , block_info : BlockInfo , nd_indexer : NDIndexer
16571657) -> ir .Value :
16581658 full_shape = block_info .full_shape_dtype .shape
16591659 num_mapped_dims = sum (b is pallas_core .mapped for b in block_info .block_shape )
@@ -1732,14 +1732,7 @@ def _compute_offsets_from_indices(
17321732 dim_offsets = _mul (dim_offsets , _full (dim_offsets .type , dim_stride ))
17331733 offsets = _add (offsets , dim_offsets )
17341734
1735- return offsets
1736-
1737-
1738- def _compute_pointers_from_indices (
1739- root_ptr : ir .Value , block_info : BlockInfo , nd_indexer : NDIndexer
1740- ) -> ir .Value :
1741- offsets = _compute_offsets_from_indices (block_info , nd_indexer )
1742- return _add (_bcast_to (root_ptr , nd_indexer .get_indexer_shape ()), offsets )
1735+ return _add (_bcast_to (root_ptr , indexer_shape ), offsets )
17431736
17441737
17451738@register_lowering (sp .get_p )
@@ -1855,20 +1848,14 @@ def _masked_load_lowering_rule(
18551848 if not tt_dialect .PointerType .isinstance (ptr .type ):
18561849 assert len (ctx .avals_in ) == 1
18571850 return ptr
1858-
1859- offsets = _compute_offsets_from_indices (block_info , idx )
1860- ptr_offsets = offsets
1861-
1862- if block_info .full_shape_dtype .dtype in (jnp .int4 , jnp .uint4 ):
1863- ptr_offsets = _floordiv (offsets , _full (offsets .type , 2 ), signed = False )
1864-
1865- shape = idx .get_indexer_shape ()
1866- ptr = _add (_bcast_to (ptr , shape ), ptr_offsets )
1851+ ptr = _compute_pointers_from_indices (ptr , block_info , idx )
18671852 if mask is not None :
1868- mask = _bcast_to (_ensure_ir_value (mask , mask_aval ), shape )
1853+ mask = _bcast_to (_ensure_ir_value (mask , mask_aval ), idx . get_indexer_shape () )
18691854 if other is not None :
1870- other = _bcast_to (_ensure_ir_value (other , other_aval ), shape )
1871- values = _load (
1855+ other = _bcast_to (
1856+ _ensure_ir_value (other , other_aval ), idx .get_indexer_shape ()
1857+ )
1858+ return _load (
18721859 ptr ,
18731860 mask = mask ,
18741861 other = other ,
@@ -1877,19 +1864,6 @@ def _masked_load_lowering_rule(
18771864 eviction_policy = eviction_policy ,
18781865 )
18791866
1880- if block_info .full_shape_dtype .dtype not in (jnp .int4 , jnp .uint4 ):
1881- return values
1882-
1883- # XLA packs pairs of `[u]int4` values into a `uint8` value with the first
1884- # in the most significant bits and the second in the least significant.
1885- offsets = _ir_cast (offsets , ir .IntegerType .get_signless (32 ), signed = False )
1886- in_lsb = _mod (offsets , _full (offsets .type , 2 ), signed = False )
1887- in_msb = arith_dialect .xori (in_lsb , _full (in_lsb .type , 1 ))
1888- shift = _mul (in_msb , _full (in_msb .type , 4 ))
1889- shift = _ir_cast (shift , values .type , signed = False )
1890- values = arith_dialect .shrui (values , shift )
1891- return _ir_cast (values , ir .IntegerType .get_signless (4 ), signed = False )
1892-
18931867
18941868@register_lowering (sp .swap_p )
18951869def _swap_lowering_rule (ctx : LoweringRuleContext , ptr , value , * idx , tree ):
0 commit comments