diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 697cb35a59a28..237aab4d7f309 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -27,7 +27,7 @@ using namespace mlir::nvgpu;
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
 
-void nvgpu::NVGPUDialect::initialize() {
+void NVGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
 #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
@@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() {
       >();
 }
 
-bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
   if (!memorySpace)
     return false;
   if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
@@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
   return false;
 }
 
-bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
   Attribute memorySpace = type.getMemorySpace();
   return isSharedMemoryAddressSpace(memorySpace);
 }
@@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
                                      TypedValue<VectorType> matrixC,
                                      const std::array<int64_t, 3> &mmaShape,
                                      bool tf32Enabled, bool sparse = false) {
-
   // The verification for mma.sync covering various shapes and data types is
   // based on the fundamental tensor core shape.
 
@@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() {
 // NVGPU_LdMatrixOp
 //===----------------------------------------------------------------------===//
 LogicalResult LdMatrixOp::verify() {
-
   // ldmatrix reads data from source in shared memory
   auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
 
@@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//
 
-unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
+static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
   switch (kind) {
   case TensorMapSwizzleKind::SWIZZLE_32B:
     return 32;
@@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
 }
 
 std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
-    Operation *op, nvgpu::TensorMapDescriptorType descType,
+    Operation *op, TensorMapDescriptorType descType,
     std::optional<MemRefType> memrefType = std::nullopt) {
   MemRefType descMemref = descType.getTensor();
   // Limitation
@@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
-
-  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+  WarpgroupAccumulatorType accType = getMatrixC().getType();
   int64_t sizeM = accType.getFragmented().getDimSize(0);
   int64_t sizeN = accType.getFragmented().getDimSize(1);
   Type elemType = accType.getFragmented().getElementType();
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 46e82bd8fc8c8..2a857eddbb932 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -43,7 +43,7 @@ using namespace mlir::transform;
 // Apply...ConversionPatternsOp
 //===----------------------------------------------------------------------===//
 
-void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
+void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
   auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
   /// device-side async tokens cannot be materialized in nvvm. We just
@@ -62,62 +62,58 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
     llvm_unreachable("unknown address space enum value");
     return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
   });
-  llvmTypeConverter.addConversion(
-      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
-        return llvmTypeConverter.convertType(
-            IntegerType::get(type.getContext(), 32));
-      });
-  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
+  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
+    return llvmTypeConverter.convertType(
+        IntegerType::get(type.getContext(), 32));
+  });
+  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
     return llvmTypeConverter.convertType(
         IntegerType::get(type.getContext(), 64));
   });
-  llvmTypeConverter.addConversion(
-      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
-        Type elemType = type.getFragmented().getElementType();
-        int64_t sizeM = type.getFragmented().getDimSize(0);
-        int64_t sizeN = type.getFragmented().getDimSize(1);
-
-        unsigned numMembers;
-        if (elemType.isF32() || elemType.isInteger(32))
-          numMembers = sizeN / 2;
-        else if (elemType.isF16())
-          numMembers = sizeN / 4;
-        else
-          llvm_unreachable("unsupported type for warpgroup accumulator");
-
-        SmallVector<Type> innerStructBody;
-        for (unsigned i = 0; i < numMembers; i++)
-          innerStructBody.push_back(elemType);
-        auto innerStructType = LLVM::LLVMStructType::getLiteral(
-            type.getContext(), innerStructBody);
-
-        SmallVector<Type> structBody;
-        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
-          structBody.push_back(innerStructType);
-
-        auto convertedType =
-            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
-        return llvmTypeConverter.convertType(convertedType);
-      });
-  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
+  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
+    Type elemType = type.getFragmented().getElementType();
+    int64_t sizeM = type.getFragmented().getDimSize(0);
+    int64_t sizeN = type.getFragmented().getDimSize(1);
+
+    unsigned numMembers;
+    if (elemType.isF32() || elemType.isInteger(32))
+      numMembers = sizeN / 2;
+    else if (elemType.isF16())
+      numMembers = sizeN / 4;
+    else
+      llvm_unreachable("unsupported type for warpgroup accumulator");
+
+    SmallVector<Type> innerStructBody;
+    for (unsigned i = 0; i < numMembers; i++)
+      innerStructBody.push_back(elemType);
+    auto innerStructType =
+        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
+    SmallVector<Type> structBody;
+    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+      structBody.push_back(innerStructType);
+
+    auto convertedType =
+        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+    return llvmTypeConverter.convertType(convertedType);
+  });
+  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
     return llvmTypeConverter.convertType(
         getMBarrierMemrefType(type.getContext(), type));
   });
   llvmTypeConverter.addConversion(
-      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+      [&](WarpgroupMatrixDescriptorType type) -> Type {
         return llvmTypeConverter.convertType(
             IntegerType::get(type.getContext(), 64));
       });
-  llvmTypeConverter.addConversion(
-      [&](nvgpu::TensorMapDescriptorType type) -> Type {
-        return LLVM::LLVMPointerType::get(type.getContext());
-      });
+  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
+    return LLVM::LLVMPointerType::get(type.getContext());
+  });
   populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
 }
 
-LogicalResult
-transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
-    transform::TypeConverterBuilderOpInterface builder) {
+LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
+    TypeConverterBuilderOpInterface builder) {
   if (builder.getTypeConverterType() != "LLVMTypeConverter")
     return emitOpError("expected LLVMTypeConverter");
   return success();
@@ -127,17 +123,18 @@ transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
 // CreateAsyncGroupsOp
 //===---------------------------------------------------------------------===//
 
-void transform::CreateAsyncGroupsOp::getEffects(
+void CreateAsyncGroupsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::consumesHandle(getTargetMutable(), effects);
-  transform::producesHandle(getOperation()->getOpResults(), effects);
-  transform::modifiesPayload(effects);
+  consumesHandle(getTargetMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
 }
 
-DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
-    TransformRewriter &rewriter, Operation *target,
-    ApplyToEachResultList &results, TransformState &state) {
-  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
+DiagnosedSilenceableFailure
+CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
+                                ApplyToEachResultList &results,
+                                TransformState &state) {
+  createAsyncGroups(rewriter, target, getBypassL1());
   results.push_back(target);
   return DiagnosedSilenceableFailure::success();
 }
@@ -218,7 +215,7 @@ collectStage0PipeliningOps(scf::ForOp forOp,
       continue;
     }
 
-    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
+    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
@@ -246,7 +243,7 @@ setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                            unsigned iteration, unsigned depth) {
   // Based on the order of copies within the loop we need to set the number
   // of copies in flight, unless it is already set.
-  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
+  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
   if (!waitOp || waitOp.getNumGroups())
     return;
 
@@ -312,13 +309,12 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
   // original number of iterations, in particular side-effect free operations
   // and barriers, even if they cannot be predicated.
   if (isMemoryEffectFree(op) ||
-      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
-          nvgpu::DeviceAsyncWaitOp>(op)) {
+      isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
     return op;
   }
 
   // Otherwise, only async copies can currently be predicated.
-  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
+  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
   if (!asyncCopyOp)
     return nullptr;
 
@@ -335,8 +331,8 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
   Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
   auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                              originalSrcElement, c0Index);
-  auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
-      rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
+      rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
       asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
       asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
       UnitAttr());
@@ -805,17 +801,16 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
                                            rhsIndexFn, rhsShape);
   Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                             resIndexFn, resShape);
-  res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
-                                 info.tf32Enabled);
+  res =
+      MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
   buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                  resShape);
   return res.getDefiningOp();
 }
 
-DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
+DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
+    TransformRewriter &rewriter, LinalgOp linalgOp,
+    ApplyToEachResultList &results, TransformState &state) {
   bool fail = true;
   // TODO: more robust detection of matmulOp, with transposes etc.
   if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
@@ -854,43 +849,42 @@ struct HopperBuilder {
   HopperBuilder(RewriterBase &rewriter, Location loc)
       : rewriter(rewriter), loc(loc) {}
 
-  TypedValue<nvgpu::MBarrierGroupType>
+  TypedValue<MBarrierGroupType>
   buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
 
   /// Create tma descriptor op to initiate transfer from global to shared
   /// memory. This must be done before the launch op, on the host.
-  TypedValue<nvgpu::TensorMapDescriptorType>
+  TypedValue<TensorMapDescriptorType>
   buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                               gpu::LaunchOp launchOp);
 
   /// Build a tma load from global memory to shared memory using `barrier` to
   /// synchronize. Return the number of bytes that will be transferred.
-  OpFoldResult
-  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
-                    TypedValue<MemRefType> sharedMemref,
-                    TypedValue<nvgpu::MBarrierGroupType> barrier,
-                    SmallVectorImpl<Operation *> &loadOps);
-  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
+  OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+                                 TypedValue<MemRefType> sharedMemref,
+                                 TypedValue<MBarrierGroupType> barrier,
+                                 SmallVectorImpl<Operation *> &loadOps);
+  void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                             ArrayRef<OpFoldResult> sizes);
 
   /// If threadIdx.x == 0 does TMA request + wait, else just wait.
   /// Return the operation that performs the transfer on thread0.
   // TODO: In the future, don't hardcode to thread 0 but elect a leader.
   SmallVector<Operation *> buildPredicateLoadsOnThread0(
-      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+      ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
       ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
-      TypedValue<nvgpu::MBarrierGroupType> barrier);
+      TypedValue<MBarrierGroupType> barrier);
 
-  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
+  void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);
 
   RewriterBase &rewriter;
   Location loc;
 };
 
 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
-    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+    ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
     ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
-    TypedValue<nvgpu::MBarrierGroupType> barrier) {
+    TypedValue<MBarrierGroupType> barrier) {
   SmallVector<Operation *> loadOps;
   Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
   Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
@@ -931,22 +925,22 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
   // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
 }
 
-TypedValue<nvgpu::MBarrierGroupType>
+TypedValue<MBarrierGroupType>
 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
-  Value barrier = nvgpu::MBarrierCreateOp::create(
+  Value barrier = MBarrierCreateOp::create(
       rewriter, loc,
-      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
+      MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
   Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
   nvgpu::MBarrierInitOp::create(
       rewriter, loc, barrier,
       getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
       Value());
   gpu::BarrierOp::create(rewriter, loc);
-  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
+  return cast<TypedValue<MBarrierGroupType>>(barrier);
 }
 
-TypedValue<nvgpu::TensorMapDescriptorType>
+TypedValue<TensorMapDescriptorType>
 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                            gpu::LaunchOp launchOp) {
   OpBuilder::InsertionGuard guard(rewriter);
@@ -962,29 +956,29 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
       getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
 
   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
-  Value desc = nvgpu::TmaCreateDescriptorOp::create(
+  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
-      nvgpu::TensorMapDescriptorType::get(
-          rewriter.getContext(),
-          MemRefType::Builder(memref.getType())
-              .setMemorySpace(sharedMemorySpace),
-          TensorMapSwizzleKind::SWIZZLE_NONE,
-          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
-          TensorMapInterleaveKind::INTERLEAVE_NONE),
+      TensorMapDescriptorType::get(rewriter.getContext(),
+                                   MemRefType::Builder(memref.getType())
+                                       .setMemorySpace(sharedMemorySpace),
+                                   TensorMapSwizzleKind::SWIZZLE_NONE,
+                                   TensorMapL2PromoKind::L2PROMO_NONE,
+                                   TensorMapOOBKind::OOB_ZERO,
+                                   TensorMapInterleaveKind::INTERLEAVE_NONE),
       unrankedMemRef, sizes);
-  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
+  return cast<TypedValue<TensorMapDescriptorType>>(desc);
 }
 
-OpFoldResult HopperBuilder::buildTmaAsyncLoad(
-    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
-    TypedValue<MemRefType> sharedMemref,
-    TypedValue<nvgpu::MBarrierGroupType> barrier,
-    SmallVectorImpl<Operation *> &loadOps) {
+OpFoldResult
+HopperBuilder::buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+                                 TypedValue<MemRefType> sharedMemref,
+                                 TypedValue<MBarrierGroupType> barrier,
+                                 SmallVectorImpl<Operation *> &loadOps) {
   MLIRContext *ctx = rewriter.getContext();
   Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
-  Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
-      rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
-      zero, Value(), Value());
+  Operation *loadOp =
+      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
+                             ValueRange{zero, zero}, zero, Value(), Value());
   loadOps.push_back(loadOp);
   auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
   SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -997,9 +991,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
   return res;
 }
 
-void HopperBuilder::buildBarrierArriveTx(
-    TypedValue<nvgpu::MBarrierGroupType> barrier,
-    ArrayRef<OpFoldResult> mixedSizes) {
+void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
+                                         ArrayRef<OpFoldResult> mixedSizes) {
   assert(!mixedSizes.empty() && "expecte non-empty sizes");
   MLIRContext *ctx = rewriter.getContext();
   SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -1013,8 +1006,7 @@ void HopperBuilder::buildBarrierArriveTx(
       Value());
 }
 
-void HopperBuilder::buildTryWaitParity(
-    TypedValue<nvgpu::MBarrierGroupType> barrier) {
+void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
   Type i1 = rewriter.getI1Type();
   Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
   // 10M is an arbitrary, not too small or too big number to specify the number
@@ -1058,11 +1050,11 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
       ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
 
-  TypedValue<nvgpu::MBarrierGroupType> barrier =
+  TypedValue<MBarrierGroupType> barrier =
       buildAndInitBarrierInSharedMemory(numThreads);
 
   SmallVector<TypedValue<MemRefType>> shmems;
-  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
+  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
   for (Operation *op : copyOps) {
     auto copyOp = cast<linalg::CopyOp>(op);
     auto inMemRef =
@@ -1071,7 +1063,7 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
            "expected in to be a 2D memref");
 
     // 2. Build global memory descriptor.
-    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
+    TypedValue<TensorMapDescriptorType> globalDesc =
         buildGlobalMemRefDescriptor(inMemRef, launchOp);
     globalDescs.push_back(globalDesc);
 
@@ -1098,9 +1090,8 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
 }
 
 DiagnosedSilenceableFailure
-transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
-                                     transform::TransformResults &results,
-                                     transform::TransformState &state) {
+RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
+                          TransformResults &results, TransformState &state) {
   auto payloadOps = state.getPayloadOps(getTarget());
   gpu::LaunchOp commonLaunchOp;
   Operation *firstOp, *failingOp;
@@ -1137,15 +1128,14 @@ transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
 
 namespace {
 class NVGPUTransformDialectExtension
-    : public transform::TransformDialectExtension<
-          NVGPUTransformDialectExtension> {
+    : public TransformDialectExtension<NVGPUTransformDialectExtension> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
 
   NVGPUTransformDialectExtension() {
     declareGeneratedDialect<arith::ArithDialect>();
     declareGeneratedDialect<affine::AffineDialect>();
-    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<NVGPUDialect>();
     declareGeneratedDialect<NVVM::NVVMDialect>();
     declareGeneratedDialect<vector::VectorDialect>();
     registerTransformOps<
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 5b89c87e2c1ff..7f626a625aaea 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -64,6 +64,5 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
 
 void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
     RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
-
   patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
 }
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 809d634f05905..9e5ea93769cdc 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                          const WarpMatrixInfo &fragmentType) {
   Type elementType = fragmentType.vectorType.getElementType();
   ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
-  FailureOr<nvgpu::FragmentElementInfo> regInfo =
-      getMmaSyncRegisterType(fragmentType);
+  FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
   if (failed(regInfo))
     return failure();
 
@@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
       (logicalValueIdDim % elementsPerRegister)});
 }
 
-FailureOr<nvgpu::LdMatrixParams>
-nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
+FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
+                                                   bool transpose) {
   LdMatrixParams params;
   Type elType = type.vectorType.getElementType();
   params.fragmentType = type.vectorType;
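
Note on the pattern applied throughout this patch: the redundant `nvgpu::` and `transform::` qualifiers can be dropped only because the affected files already have the corresponding using-directives in scope (visible in the hunk context above, e.g. `using namespace mlir::nvgpu;` and `using namespace mlir::transform;`). The standalone sketch below, with made-up names that are not part of MLIR or this patch, illustrates why the shortened spellings resolve to the same entities:

// Illustrative sketch only: "toolkit" and "Widget" are invented for this
// example and do not appear in the patch.
#include <cstdio>

namespace toolkit {
struct Widget {
  static int size() { return 42; }
};
} // namespace toolkit

// Once this using-directive is in scope (the patch relies on the analogous
// `using namespace mlir::nvgpu;` / `using namespace mlir::transform;`),
// the explicit namespace qualifier becomes redundant.
using namespace toolkit;

int main() {
  std::printf("%d\n", toolkit::Widget::size()); // fully qualified spelling
  std::printf("%d\n", Widget::size());          // equivalent unqualified spelling
  return 0;
}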