-
Notifications
You must be signed in to change notification settings - Fork 50
[DRAFT] Implement DPP Reduction in wavefront #1796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 6 commits
6384eea
12b98ab
c291f81
2c24cf8
f9d7417
df7dd42
e9888e1
88b9639
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,10 +39,10 @@ | |
| #include "mlir/IR/BuiltinTypes.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
|
|
||
| #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" | ||
| #include "mlir/Dialect/Rock/IR/AccelEmitter.h" | ||
| #include "llvm/ADT/SmallVector.h" | ||
| #include "llvm/Support/Debug.h" | ||
|
|
||
| namespace mlir { | ||
| namespace rock { | ||
| #define GEN_PASS_DEF_ROCKBLOCKWISEGEMMTOTHREADWISEPASS | ||
|
|
@@ -1098,31 +1098,132 @@ struct BlockwiseReduceRewritePattern | |
| vectorTypeOrSelf(elemType, | ||
| std::max(rIterVectorLen, nrIterVectorLen)), | ||
| workspaceLDSBuffer, LDSLoadCoords); | ||
| Value loadAcc = rewriter.create<InBoundsLoadOp>( | ||
| loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg, | ||
| zeroConstantOp); | ||
| Value reduced = createReducingOp(op, loadVal, loadAcc, rewriter); | ||
| rewriter.create<InBoundsStoreOp>(loc, reduced, accReg, | ||
| zeroConstantOp); | ||
| // Storing the last reduction iter output directly to LDS[..., dr=0, | ||
| // ...] | ||
| Value rIterArg = | ||
| reductionLoop.getLowerCoords(/*domain=*/1)[rIterDim]; | ||
| Value boundVal = rewriter.create<arith::ConstantIndexOp>( | ||
| loc, threadViewShape[rIterDim]); | ||
| Value strideVal = | ||
| rewriter.create<arith::ConstantIndexOp>(loc, rIterVectorLen); | ||
| Value lastIterVal = | ||
| rewriter.create<arith::SubIOp>(loc, boundVal, strideVal); | ||
| Value isLastIter = rewriter.create<arith::CmpIOp>( | ||
| loc, arith::CmpIPredicate::eq, rIterArg, lastIterVal); | ||
| scf::IfOp ifb = rewriter.create<scf::IfOp>( | ||
| loc, isLastIter, /*withElseRegion=*/false); | ||
| { | ||
| OpBuilder thenb = ifb.getThenBodyBuilder(); | ||
| thenb.create<InBoundsStoreOp>( | ||
| loc, reduced, workspaceLDSBuffer, | ||
| reductionLoop.getLowerCoords(/*domain=*/2)); | ||
|
|
||
| Value BrodcastAll; | ||
| if ((threadViewShape[rIterDim] / rIterVectorLen) > 1) { | ||
| auto vecType = dyn_cast<VectorType>(loadVal.getType()); | ||
| auto vecLen = vecType.getNumElements(); | ||
| SmallVector<Value, 4> scalarDppResults; | ||
|
|
||
| for (int64_t i = 0; i < vecLen; ++i) { | ||
|
|
||
| Value scalarVal = rewriter.create<vector::ExtractElementOp>( | ||
| loc, loadVal, | ||
| rewriter.create<arith::ConstantIndexOp>(loc, i)); | ||
| Value scalarInactiveValue = rewriter.create<arith::ConstantOp>( | ||
| loc, vecType.getElementType(), | ||
| rewriter.getFloatAttr(vecType.getElementType(), 0.0)); | ||
|
||
|
|
||
| Value setInactiveScalar = rewriter.create<ROCDL::SetInactiveOp>( | ||
| loc, vecType.getElementType(), scalarVal, | ||
| scalarInactiveValue); | ||
|
|
||
| Value dppResult1 = rewriter.create<amdgpu::DPPOp>( | ||
|
||
| loc, elemType, setInactiveScalar, setInactiveScalar, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_shr), | ||
| rewriter.getI32IntegerAttr(1), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| Value dppResult = createReducingOp(op, setInactiveScalar, | ||
| dppResult1, rewriter); | ||
|
|
||
| Value dppResult2 = rewriter.create<amdgpu::DPPOp>( | ||
|
||
| loc, elemType, setInactiveScalar, setInactiveScalar, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_shr), | ||
| rewriter.getI32IntegerAttr(2), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppResult2, rewriter); | ||
|
|
||
| Value dppResult3 = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, setInactiveScalar, setInactiveScalar, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_shr), | ||
| rewriter.getI32IntegerAttr(3), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppResult3, rewriter); | ||
|
|
||
| Value dppResult4 = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, dppResult, dppResult, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_shr), | ||
| rewriter.getI32IntegerAttr(4), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xE), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppResult4, rewriter); | ||
| Value dppResult5 = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, dppResult, dppResult, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_shr), | ||
| rewriter.getI32IntegerAttr(8), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xC), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppResult5, rewriter); | ||
|
|
||
| Value dppBrodcast = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, dppResult, dppResult, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_bcast_15), | ||
| nullptr, rewriter.getI32IntegerAttr(0xA), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppBrodcast, rewriter); | ||
|
|
||
| dppBrodcast = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, dppResult, dppResult, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::row_bcast_31), | ||
| nullptr, rewriter.getI32IntegerAttr(0xC), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppResult = | ||
| createReducingOp(op, dppResult, dppBrodcast, rewriter); | ||
|
|
||
| Value dppRotated = rewriter.create<amdgpu::DPPOp>( | ||
| loc, elemType, dppResult, dppResult, | ||
| amdgpu::DPPPermAttr::get(rewriter.getContext(), | ||
| amdgpu::DPPPerm::wave_ror), | ||
| nullptr, rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getI32IntegerAttr(0xF), | ||
| rewriter.getBoolAttr(false)); | ||
|
|
||
| dppRotated = rewriter.create<ROCDL::StrictWWMOp>(loc, elemType, | ||
| dppRotated); | ||
| BrodcastAll = rewriter.create<ROCDL::ReadlaneOp>( | ||
| loc, elemType, dppRotated, | ||
| rewriter.create<mlir::arith::ConstantIntOp>(loc, 0, 32)); | ||
| } | ||
| rewriter.create<InBoundsStoreOp>(loc, BrodcastAll, | ||
| workspaceLDSBuffer, | ||
| reductionLoop.getLowerCoords(2)); | ||
| } else { | ||
| Value loadAcc = rewriter.create<InBoundsLoadOp>( | ||
| loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg, | ||
| zeroConstantOp); | ||
| BrodcastAll = createReducingOp(op, loadVal, loadAcc, rewriter); | ||
| rewriter.create<InBoundsStoreOp>(loc, BrodcastAll, | ||
| workspaceLDSBuffer, | ||
| reductionLoop.getLowerCoords(2)); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1319,9 +1420,10 @@ void RockLowerBlockwiseGemmToThreadwisePass::runOnOperation() { | |
| { | ||
| ConversionTarget writeAllTarget(*ctx); | ||
| writeAllTarget.addIllegalOp<BlockwiseBroadcastReduceOp, BlockwiseFillOp>(); | ||
| writeAllTarget.addLegalDialect<arith::ArithDialect, rock::RockDialect, | ||
| memref::MemRefDialect, scf::SCFDialect, | ||
| vector::VectorDialect, AffineDialect>(); | ||
| writeAllTarget.addLegalDialect< | ||
| arith::ArithDialect, rock::RockDialect, memref::MemRefDialect, | ||
| scf::SCFDialect, vector::VectorDialect, AffineDialect, | ||
| ROCDL::ROCDLDialect, amdgpu::AMDGPUDialect>(); | ||
| writeAllTarget.addLegalOp<gpu::PrintfOp>(); | ||
| RewritePatternSet writeAllPatterns(ctx); | ||
| writeAllPatterns | ||
|
|
@@ -1335,7 +1437,8 @@ void RockLowerBlockwiseGemmToThreadwisePass::runOnOperation() { | |
| target.addIllegalOp<FillOp, BlockwiseGemmOp, BlockwiseGemmAccelOp>(); | ||
| target.addLegalDialect<arith::ArithDialect, rock::RockDialect, | ||
| affine::AffineDialect, vector::VectorDialect, | ||
| memref::MemRefDialect>(); | ||
| memref::MemRefDialect, ROCDL::ROCDLDialect, | ||
| amdgpu::AMDGPUDialect>(); | ||
| target.addLegalOp<gpu::PrintfOp>(); | ||
|
|
||
| RewritePatternSet patterns(ctx); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need the ROCDL and AMDGPU dialects here. Could we introduce a new op instead — something like rock::WaveReduction — and lower it to ROCDL in a later pass?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, that way it is easier to keep the current implementation as a fallback when DPP is not supported.