@@ -11,6 +11,16 @@ using namespace mlir::triton;
1111using namespace mlir ::triton::gpu;
1212using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
1313namespace {
14+
15+ Value bitOrPtrCast (Value val, Type type, TritonLLVMOpBuilder &b) {
16+ if (isa<LLVM::LLVMPointerType>(val.getType ()) &&
17+ !isa<LLVM::LLVMPointerType>(type)) {
18+ return b.ptrtoint (type, val);
19+ } else {
20+ return b.bitcast (val, type);
21+ }
22+ }
23+
1424struct SplatOpConversion : public ConvertOpToLLVMPattern <triton::SplatOp> {
1525 using ConvertOpToLLVMPattern<triton::SplatOp>::ConvertOpToLLVMPattern;
1626 // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern<triton::SplatOp> {
3949 unsigned ratio = srcBitWidth / cstBitWidth;
4050 Type intTy = IntegerType::get (elemType.getContext (), cstBitWidth);
4151 VectorType vecType = VectorType::get (ratio, intTy);
42- Value intCst = b. bitcast (constVal, intTy);
52+ Value intCst = bitOrPtrCast (constVal, intTy, b );
4353 Value vec = b.undef (vecType);
4454 for (unsigned i = 0 ; i < ratio; ++i)
4555 vec = b.insert_element (vecType, vec, intCst, b.int_val (32 , i));
4656 constVal = vec;
4757 }
48- auto llSrc = b. bitcast (constVal, srcType);
58+ Value llSrc = bitOrPtrCast (constVal, srcType, b );
4959 size_t elemsPerThread = getTotalElemsPerThread (tensorTy);
5060 llvm::SmallVector<Value> elems (elemsPerThread, llSrc);
5161 return packLLElements (loc, typeConverter, elems, rewriter, resType);
0 commit comments