@@ -706,17 +706,33 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
@lowering.register_lowering_rule(broadcasted_iota_p)
def _broadcasted_iota_lowering(
    ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout
):
  """Lowers `broadcasted_iota_p` to a `FragmentedArray` of per-element indices.

  Every element of the result holds its own index along `dimension`,
  converted to the MLIR element type derived from `dtype`.

  Args:
    ctx: Lowering rule context; unused by this rule.
    dtype: Element dtype of the result. Must expose `itemsize` when it is an
      integer dtype.
    shape: Shape passed to `FragmentedArray.splat`.
    dimension: Position in the per-element index tuple whose value is stored.
    layout: Object whose `.value` is passed to `FragmentedArray.splat` as the
      layout.

  Returns:
    The `FragmentedArray` produced by `foreach(..., create_array=True)`.
  """
  del ctx
  is_integer = jnp.issubdtype(dtype, jnp.integer)

  # Unsigned integers (as opposed to signless) cause MLIR verification
  # errors so we only use signless like Mosaic GPU does.
  #
  # TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead.
  if is_integer:
    mlir_dtype = ir.IntegerType.get_signless(dtype.itemsize * 8)
  else:
    mlir_dtype = mlir.dtype_to_ir_type(dtype)

  # Signedness is only meaningful for integer dtypes; None otherwise.
  is_signed = jnp.issubdtype(dtype, jnp.signedinteger) if is_integer else None

  # Placeholder value for the splat; every element is overwritten by the
  # `foreach` below.
  undef = llvm_dialect.mlir_undef(mlir_dtype)

  i32 = ir.IntegerType.get_signless(32)
  # The target type does not change between elements, so classify it once
  # rather than inside the per-element callback.
  is_float = ir.FloatType.isinstance(mlir_dtype)

  def _cast(index_value):
    if is_float:
      # Go from `index` to a float via a 32-bit integer, then an
      # unsigned-int-to-float conversion.
      as_i32 = arith_dialect.index_cast(i32, index_value)
      return arith_dialect.uitofp(mlir_dtype, as_i32)
    return arith_dialect.index_cast(mlir_dtype, index_value)

  splat = mgpu.FragmentedArray.splat(
      undef, shape, layout.value, is_signed=is_signed
  )
  return splat.foreach(
      lambda _, idx: _cast(idx[dimension]),
      create_array=True,
      is_signed=is_signed,
  )
721737
722738
0 commit comments