diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index d58558ac32884..eadb5d9326798 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp : let hasVerifier = 1; } +def AMDGPU_TransposeLoadOp : + AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>, + Arguments<(ins Arg:$src, Variadic:$srcIndices)>, + Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> { + let summary = "MLIR wrapper for CDNA Transpose Load instructions"; + let description = [{ + The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions. + The transpose load op represents a subgroup load from LDS memory, + where the subgroup of threads collectively reads a matrix from the source + memref, with each thread reading a vector of the matrix, and gets the transposed matrix + as the result. That is, each thread reads a vector of the col-major matrix at different + indices, and the thread's read result is a vector of the corresponding row of the transposed + matrix. + + This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer + to the CDNA4 ISA documentation for more details about its exact semantics. + + Format example: + ``` + %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16> + ``` + Operands: + * `$src`: LDS memref to read from. + * `$srcIndices`: indices into `$src` to read from for this thread. + * `$result`: target register this transpose load instruction will write to. + + Note: Lowering is only supported on gfx950 and up. 
+ }]; + let assemblyFormat = [{ + $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result) + }]; + let hasVerifier = 1; +} + def AMDGPU_ScaledMFMAOp : AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>, Pure]>, diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 700563460f525..910fe1b1d93c1 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { } }; +struct TransposeLoadOpLowering + : public ConvertOpToLLVMPattern { + TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset != kGfx950) + return op.emitOpError("Non-gfx950 chipset not supported"); + + Location loc = op.getLoc(); + auto srcMemRefType = cast(op.getSrc().getType()); + + // Elements in subbyte memrefs are stored non-contiguously, + // reject if source is sub-byte memref. Use emulated memrefs instead. + size_t srcElementSize = + srcMemRefType.getElementType().getIntOrFloatBitWidth(); + if (srcElementSize < 8) + return op.emitOpError("Expect source memref to have at least 8 bits " + "element size, got ") + << srcElementSize; + + auto resultType = cast(op.getResult().getType()); + Value srcPtr = + getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(), + (adaptor.getSrcIndices())); + + size_t numElements = resultType.getNumElements(); + size_t elementTypeSize = + resultType.getElementType().getIntOrFloatBitWidth(); + + // ROCDL transpose load intrinsics return vectors of 32-bit integers, if + // the element size is smaller than 16 bits. 
+ Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32, + rewriter.getIntegerType(32)); + Type llvmResultType = typeConverter->convertType(resultType); + + switch (elementTypeSize) { + case 4: { + assert(numElements == 16); + auto rocdlOp = + rewriter.create(loc, rocdlResultType, srcPtr); + rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); + break; + } + case 6: { + assert(numElements == 16); + auto rocdlOp = + rewriter.create(loc, rocdlResultType, srcPtr); + rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); + break; + } + case 8: { + assert(numElements == 8); + auto rocdlOp = + rewriter.create(loc, rocdlResultType, srcPtr); + rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); + break; + } + case 16: { + assert(numElements == 4); + rewriter.replaceOpWithNewOp(op, llvmResultType, + srcPtr); + break; + } + default: + return op.emitOpError("Unsupported element size for transpose load"); + } + return success(); + } +}; + struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern { GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} @@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter, - chipset); + PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, + TransposeLoadOpLowering>(converter, chipset); patterns.add(converter); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 0d0add3094666..4613d14461969 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include 
"mlir/IR/TypeUtilities.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() { return success(); } +LogicalResult TransposeLoadOp::verify() { + MemRefType srcType = cast(getSrc().getType()); + + if (!hasWorkgroupMemorySpace(srcType.getMemorySpace())) + return emitOpError("source memory address space must be Workgroup"); + + auto transferType = cast(getType()); + size_t numElements = transferType.getNumElements(); + size_t elementTypeSize = + transferType.getElementType().getIntOrFloatBitWidth(); + + // ElementSize -> NumElements + const llvm::SmallDenseMap KValidLoadSizeMap = { + {4, 16}, + {6, 16}, + {8, 8}, + {16, 4}, + }; + + auto validNumElems = KValidLoadSizeMap.find(elementTypeSize); + if (validNumElems == KValidLoadSizeMap.end()) { + return emitOpError("Unsupported element type size for transpose load: ") + << elementTypeSize << " bits"; + } + if (numElements != validNumElems->second) { + return emitOpError( + "Transferring type size mismatch: expected num of elements: ") + << validNumElems->second; + } + + return success(); +} + #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" #define GET_ATTRDEF_CLASSES diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir new file mode 100644 index 0000000000000..68799098f1d36 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s +// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD + +// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16 +func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> { + // CHECK: rocdl.ds.read.tr16.b64 + // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not 
supported + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16> + return %0 : vector<4xf16> +} + +// ----- + +// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8 +func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> { + // CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64 + // CHECK-SAME: -> vector<2xi32> + // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8> + // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8> + return %0 : vector<8xi8> +} + +// ----- + +// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8 +func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> { + // CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64 + // CHECK-SAME: -> vector<2xi32> + // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4> + // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4> + return %0 : vector<16xi4> +} + +// ----- + +// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8 +func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> { + // CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96 + // CHECK-SAME: -> vector<3xi32> + // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6> + // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6> + return %0 : vector<16xi6> +} + +// ----- + +// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8 +func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> 
{ + // CHECK: rocdl.ds.read.tr16.b64 + // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16> + return %0 : vector<4xi16> +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir new file mode 100644 index 0000000000000..a41051c904ed8 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir @@ -0,0 +1,17 @@ +// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s + +// ----- + +func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> { + // CHECK: memref to have at least 8 bits element size, got 4 + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4> + return %0 : vector<16xi4> +} + +// ----- + +func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> { + // CHECK: memref to have at least 8 bits element size, got 6 + %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6> + return %0 : vector<16xi6> +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 73306ba6b3f93..6d55583f8bc7c 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> { %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32> func.return %0 : vector<[4]xf32> } + +// ----- + +func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> { + // expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> 
vector<4xf16> + func.return %0 : vector<4xf16> +} + +// ----- + +func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> { + // expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16> + func.return %0 : vector<4xf16> +} + +// ----- + +func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> { + // expected-error@+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32> + func.return %0 : vector<4xf32> +} + +// ----- + +func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> { + // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16> + func.return %0 : vector<2xf16> +} + +// ----- + +func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> { + // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4> + func.return %0 : vector<20xi4> +} + +// ----- + +func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> { + // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8> + func.return %0 : vector<20xi8> +} + +// ----- + +func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : 
memref<128x32xi6, 3>) -> vector<8xi6> { + // expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}} + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6> + func.return %0 : vector<8xi6> +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 6c3ffb575f7c2..51f3bbd9ae45c 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32> func.return %0 : vector<16xf32> } + +// CHECK-LABEL: func @transpose_load +func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> { + // CHECK: amdgpu.transpose_load + %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16> + func.return %0 : vector<4xf16> +}