-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper #137498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c43bc26
3ba7ea8
846c389
02f5d98
8a9face
b38e93f
80bf07d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -833,6 +833,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) { | |
| mfma.getBlocks(), chipset); | ||
| } | ||
|
|
||
| static std::optional<std::tuple<StringRef, uint32_t, uint32_t>> | ||
| mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { | ||
| return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rather than creating an overloaded function, you can pass the `Operation *` and then do a `dyn_cast` with an if/else.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What's the benefit of branching on the operation type?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It can be templatized as well; my point was to remove the unnecessary function that does the same thing.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IIUC, templating doesn't remove the need for branching, since I would still be branching on the class of the input op to pass in the appropriate arguments (they differ between the proposed scaled_mfma and the existing mfma). So I opted for branching on the op type, which I find clearer.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, the reason this isn't a pure |
||
| smfma.getSourceB().getType(), | ||
| smfma.getDestC().getType(), smfma.getM(), | ||
| smfma.getN(), smfma.getK(), 1u, chipset); | ||
| } | ||
|
|
||
| /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` | ||
| /// if one exists. This includes checking to ensure the intrinsic is supported | ||
| /// on the architecture you are compiling for. | ||
|
|
@@ -954,6 +962,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { | |
| } | ||
| }; | ||
|
|
||
| struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> { | ||
| ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) | ||
| : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {} | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Chipset chipset; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| Location loc = op.getLoc(); | ||
| Type outType = typeConverter->convertType(op.getDestD().getType()); | ||
| Type intrinsicOutType = outType; | ||
| if (auto outVecType = dyn_cast<VectorType>(outType)) | ||
| if (outVecType.getElementType().isBF16()) | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| intrinsicOutType = outVecType.clone(rewriter.getI16Type()); | ||
|
|
||
| if (chipset.majorVersion != 9 || chipset < kGfx908) | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return op->emitOpError("Scaled MFMA only supported on gfx908+"); | ||
| std::optional<std::tuple<StringRef, uint32_t, uint32_t>> | ||
| maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset); | ||
| if (!maybeScaledIntrinsic.has_value()) | ||
| return op.emitOpError( | ||
| "no intrinsic matching Scaled MFMA size on given chipset"); | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic); | ||
| OperationState loweredOp(loc, intrinsicName); | ||
| loweredOp.addTypes(intrinsicOutType); | ||
| loweredOp.addOperands( | ||
| {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()), | ||
| convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()), | ||
| adaptor.getDestC()}); | ||
| Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA()); | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB()); | ||
| Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA()); | ||
| Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB()); | ||
| auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; | ||
Muzammiluddin-Syed-ECE marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode), | ||
| createI32Constant(rewriter, loc, bTypeCode), | ||
| /*scale A byte=*/opselA, /*scale A=*/scaleA, | ||
| /*scale B byte=*/opselB, /*scale B=*/scaleB}); | ||
| Value lowered = rewriter.create(loweredOp)->getResult(0); | ||
| if (outType != intrinsicOutType) | ||
| lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); | ||
| rewriter.replaceOp(op, lowered); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { | ||
| WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) | ||
| : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {} | ||
|
|
@@ -1474,8 +1530,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, | |
| RawBufferOpLowering<RawBufferAtomicCmpswapOp, | ||
| ROCDL::RawPtrBufferAtomicCmpSwap>, | ||
| AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, | ||
| MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, | ||
| PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, | ||
| GatherToLDSOpLowering>(converter, chipset); | ||
| MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, | ||
| ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, | ||
| PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter, | ||
| chipset); | ||
| patterns.add<AMDGPUSwizzleBitModeLowering>(converter); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.