@@ -1866,3 +1866,108 @@ def subf(a: ir.Value, b: ir.Value):
18661866
def mulf(a: ir.Value, b: ir.Value):
  """Multiplies `a` by `b` via `arith.mulf` with the `contract` fast-math flag."""
  contract_flag = arith.FastMathFlags.contract
  return arith.mulf(a, b, fastmath=contract_flag)
1869+
1870+
def optimization_barrier(*arrays: mgpu.FragmentedArray):
  """Acts as an optimization barrier for LLVM.

  Passing arrays through this function will make sure that they are computed
  before any side-effecting operations that follow this barrier.

  Every register of every array is flattened into 32-bit values, passed
  through a single side-effecting inline-assembly statement consisting of
  `mov.b32` instructions, and then repacked into new arrays.

  Args:
    *arrays: The arrays to pass through the barrier. Supported are f32 arrays
      (with scalar or 1D vector registers) and bf16/f16 arrays whose registers
      are vectors of exactly two elements.

  Returns:
    A single array when exactly one array was passed in, otherwise a list with
    one array per argument (in argument order). Each result reuses the
    register shape, layout and signedness of the corresponding input.

  Raises:
    NotImplementedError: If an array has an unsupported dtype, or a bf16/f16
      array does not use 2-element vector registers.
  """
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)

  regs = []
  reg_dtypes = []
  reg_constraints = []
  regs_per_array = []
  repack_fns = []
  # We unpack each array into a flat list of registers, and prepare the
  # functions that invert the transform in repack_fns.
  for array in arrays:
    reg_ty = array.registers.flat[0].type
    dtype = array.mlir_dtype
    if ir.F32Type.isinstance(dtype):
      if ir.VectorType.isinstance(reg_ty):
        [vec_len] = ir.VectorType(reg_ty).shape
        array_regs = [  # pylint: disable=g-complex-comprehension
            vector.extractelement(reg, position=c(pos, index))
            for reg in array.registers.flat
            for pos in range(vec_len)
        ]
        # reg_ty is bound as a default argument so that each closure keeps the
        # register type of its own array rather than the last loop value.
        def _repack(regs, reg_ty=reg_ty):
          reg = llvm.mlir_undef(reg_ty)
          [vec_len] = ir.VectorType(reg_ty).shape
          for i_elem in range(vec_len):
            reg = llvm.insertelement(
                reg, next(regs), arith.constant(i32, i_elem)
            )
          return reg
        repack_fns.append(_repack)
      else:
        array_regs = list(array.registers.flat)
        # Scalar registers need no repacking: take the next value as is.
        repack_fns.append(next)
      reg_constraint = "f"
    elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype):
      if not ir.VectorType.isinstance(reg_ty):
        raise NotImplementedError(array.mlir_dtype)
      [vec_len] = ir.VectorType(reg_ty).shape
      if vec_len != 2:
        raise NotImplementedError(vec_len)
      # A pair of 16-bit floats is moved as a single 32-bit integer.
      i32_reg_ty = ir.VectorType.get((1,), i32)
      array_regs = [
          vector.extractelement(
              vector.bitcast(i32_reg_ty, reg), position=c(0, index)
          )
          for reg in array.registers.flat
      ]
      reg_constraint = "r"
      def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty):
        return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs)))
      repack_fns.append(_repack)
    else:
      raise NotImplementedError(array.mlir_dtype)
    regs += array_regs
    reg_dtypes += [array_regs[0].type] * len(array_regs)
    reg_constraints += [reg_constraint] * len(array_regs)
    regs_per_array.append(len(array_regs))

  # LLVM inline assembly expects all output constraints to come before the
  # input constraints, and numbers the operands in that same order. Output i
  # is therefore $i and its matching input is $(total_regs + i), no matter
  # which array the register belongs to.
  total_regs = len(regs)
  ptx_lines = ["// Optimization barrier"]
  first_reg = 0
  for num_array_regs in regs_per_array:
    ptx_lines.append("// Next array")
    ptx_lines += [
        f"mov.b32 ${i}, ${total_regs + i}"
        for i in range(first_reg, first_reg + num_array_regs)
    ]
    first_reg += num_array_regs
  ptx = ";\n\t".join(ptx_lines) + ";"
  all_reg_constraints = ",".join(
      [*(f"={cstr}" for cstr in reg_constraints), *reg_constraints]
  )

  struct_ty = ir.Type.parse(
      f"!llvm.struct<({','.join(map(str, reg_dtypes))})>"
  )
  result_struct = llvm.inline_asm(
      struct_ty, regs, ptx, all_reg_constraints,
      asm_dialect=0, has_side_effects=True,
  )
  out_regs = [
      llvm.extractvalue(dtype, result_struct, [i])
      for i, dtype in enumerate(reg_dtypes)
  ]

  # Rebuild one array per input, consuming the outputs in the same order in
  # which the inputs were flattened above.
  results = []
  regs_it = iter(out_regs)
  for array, repack_fn in zip(arrays, repack_fns, strict=True):
    num_regs = array.registers.size
    reg_ty = array.registers.flat[0].type
    if ir.VectorType.isinstance(reg_ty):
      reg_ty = ir.VectorType(reg_ty)
    new_registers = np.empty((num_regs,), dtype=object)
    for i_vreg in range(num_regs):
      reg = repack_fn(regs_it)
      assert reg.type == reg_ty, (reg.type, reg_ty)
      new_registers[i_vreg] = reg
    results.append(
        FragmentedArray(
            _registers=new_registers.reshape(array.registers.shape),
            _layout=array.layout,
            _is_signed=array.is_signed,
        )
    )
  return results[0] if len(arrays) == 1 else results
0 commit comments