diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 5a7897f233eaa..4100b086fad8b 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -91,6 +91,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { Type llvmI32 = this->typeConverter->convertType(i32); Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type()); + auto toI32 = [&](Value val) -> Value { + if (val.getType() == llvmI32) + return val; + + return rewriter.create(loc, llvmI32, val); + }; + int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8; Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); @@ -166,22 +173,22 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { Value stride = rewriter.create( loc, llvmI16, rewriter.getI16IntegerAttr(0)); Value numRecords; - if (memrefType.hasStaticShape()) { + if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) { numRecords = createI32Constant( rewriter, loc, static_cast(memrefType.getNumElements() * elementByteWidth)); } else { Value maxIndex; for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { - Value size = memrefDescriptor.size(rewriter, loc, i); - Value stride = memrefDescriptor.stride(rewriter, loc, i); + Value size = toI32(memrefDescriptor.size(rewriter, loc, i)); + Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i)); stride = rewriter.create(loc, stride, byteWidthConst); Value maxThisDim = rewriter.create(loc, size, stride); maxIndex = maxIndex ? rewriter.create(loc, maxIndex, maxThisDim) : maxThisDim; } - numRecords = rewriter.create(loc, llvmI32, maxIndex); + numRecords = maxIndex; } // Flag word: @@ -218,7 +225,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { Value strideOp; if (ShapedType::isDynamic(strides[i])) { strideOp = rewriter.create( - loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); + loc, toI32(memrefDescriptor.stride(rewriter, loc, i)), + byteWidthConst); } else { strideOp = createI32Constant(rewriter, loc, strides[i] * elementByteWidth); @@ -240,7 +248,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { sgprOffset = createI32Constant(rewriter, loc, 0); if (ShapedType::isDynamic(offset)) sgprOffset = rewriter.create( - loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); + loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset); else if (offset > 0) sgprOffset = rewriter.create( loc, sgprOffset, createI32Constant(rewriter, loc, offset)); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index a9ea44925e914..4c7515dc81051 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -30,6 +30,25 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 { func.return %0 : i32 } +// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided +func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 { + // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16) + // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32) + // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32 + // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32 + // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32 + // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32 + // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32) + // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32) + // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8> + // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32 + // CHECK: return %[[ret]] + %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32, strided<[?], offset: ?>>, i32 -> i32 + func.return %0 : i32 +} + // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_oob_off func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32) -> i32 { // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)