Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,8 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
I32EnumAttrCase<"row_mirror", 8>,
I32EnumAttrCase<"row_half_mirror", 9>,
I32EnumAttrCase<"row_bcast_15", 10>,
I32EnumAttrCase<"row_bcast_31", 11>
I32EnumAttrCase<"row_bcast_31", 11>,
I32EnumAttrCase<"row_share", 12>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::amdgpu";
Expand Down Expand Up @@ -555,6 +556,7 @@ def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result",
- Reverse within a half-row (`row_half_mirror`)
- Broadcast the 15th lane of each row to the next row (`row_bcast`)
- Broadcast lane 31 to rows 2 and 3 (`row_bcast`)
- Broadcast a lane [0-15] within row 0 to all lanes of row 0 (`row_share`)
}];
let results = (outs AnyType:$result);
let assemblyFormat = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,49 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
}];
}

// Set Inactive intrinsic
def ROCDL_SetInactiveOp : ROCDL_IntrOp<"set.inactive", [], [0],
[AllTypesMatch<["res", "src", "inactive_value"]>], 1, 0, 0>,
Arguments<(ins LLVM_Type:$src, LLVM_Type:$inactive_value)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $src `,` $inactive_value `:` type($src)
}];
let description = [{
Copies the given value while setting all inactive lanes to a specified value.
}];
}

// Strict WWM intrinsic operation
def ROCDL_StrictWWMOp : ROCDL_IntrOp<"strict.wwm", [], [0],
[AllTypesMatch<["res", "src"]>], 1, 0, 0>,
Arguments<(ins LLVM_Type:$src)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $src `:` type($src)
}];
let description = [{
Copies the active channels of the source value to the destination value,
guaranteed to be executed in Whole Wavefront Mode with all channels enabled.
}];
}

// PermLaneX16 intrinsic operation
def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
[AllTypesMatch<["res", "old", "src0", "src1", "src2"]>], 1, 0, 0,
[4, 5], ["fi", "boundControl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
I1Attr:$fi, I1Attr:$boundControl)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0)
}];
let description = [{
Performs a `permlanex16` operation with the given operands, applying the
permutation specified by $fi to the provided inputs.
}];
}

def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
ROW_HALF_MIRROR = 0x141,
BCAST15 = 0x142,
BCAST31 = 0x143,
ROW_SHARE0 = 0x150
};

auto kind = DppOp.getKind();
Expand Down Expand Up @@ -1182,6 +1183,11 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
case DPPPerm::row_bcast_31:
DppCtrl = DppCtrl::BCAST31;
break;
case DPPPerm::row_share:
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHARE0;
}
break;
}

// Check for row_mask, bank_mask, bound_ctrl if they exist and create
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,20 @@ LogicalResult DPPOp::verify() {
}
break;
}

case DPPPerm::row_share: {
if (!permArgument) {
return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
"' value not specified");
}
if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
uint32_t attrValue = intAttr.getInt();
if (attrValue < 0 || attrValue > 15) {
return emitOpError(
"Attribute value for 'row_share' must be between 0 and 15");
}
}
} break;
}
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,11 @@ func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
return %0 : f16
}

func.func @dpp_row_share(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-LABEL: func @dpp_row_share
// CHECK: rocdl.update.dpp %arg0, %arg1 with 351, 15, 15, false : i32
// CHECK: return %0 : i32
%0 = amdgpu.dpp %arg0 %arg1 row_share ( 0xf : i32 ) : i32
return %0 : i32
}
163 changes: 133 additions & 30 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace rock {
#define GEN_PASS_DEF_ROCKBLOCKWISEGEMMTOTHREADWISEPASS
Expand Down Expand Up @@ -1098,31 +1098,132 @@ struct BlockwiseReduceRewritePattern
vectorTypeOrSelf(elemType,
std::max(rIterVectorLen, nrIterVectorLen)),
workspaceLDSBuffer, LDSLoadCoords);
Value loadAcc = rewriter.create<InBoundsLoadOp>(
loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
zeroConstantOp);
Value reduced = createReducingOp(op, loadVal, loadAcc, rewriter);
rewriter.create<InBoundsStoreOp>(loc, reduced, accReg,
zeroConstantOp);
// Storing the last reduction iter output directly to LDS[..., dr=0,
// ...]
Value rIterArg =
reductionLoop.getLowerCoords(/*domain=*/1)[rIterDim];
Value boundVal = rewriter.create<arith::ConstantIndexOp>(
loc, threadViewShape[rIterDim]);
Value strideVal =
rewriter.create<arith::ConstantIndexOp>(loc, rIterVectorLen);
Value lastIterVal =
rewriter.create<arith::SubIOp>(loc, boundVal, strideVal);
Value isLastIter = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, rIterArg, lastIterVal);
scf::IfOp ifb = rewriter.create<scf::IfOp>(
loc, isLastIter, /*withElseRegion=*/false);
{
OpBuilder thenb = ifb.getThenBodyBuilder();
thenb.create<InBoundsStoreOp>(
loc, reduced, workspaceLDSBuffer,
reductionLoop.getLowerCoords(/*domain=*/2));

Value BrodcastAll;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to have rocdl and amdgpu dialects here. Can we have some new op rock::wavereduction or something like that? Then, we can lower it to rocdl later.

Copy link
Contributor

@dhernandez0 dhernandez0 Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, that way it's easier to keep the current implementation if dpp is not supported.

if ((threadViewShape[rIterDim] / rIterVectorLen) > 1) {
auto vecType = dyn_cast<VectorType>(loadVal.getType());
auto vecLen = vecType.getNumElements();
SmallVector<Value, 4> scalarDppResults;

for (int64_t i = 0; i < vecLen; ++i) {

Value scalarVal = rewriter.create<vector::ExtractElementOp>(
loc, loadVal,
rewriter.create<arith::ConstantIndexOp>(loc, i));
Value scalarInactiveValue = rewriter.create<arith::ConstantOp>(
loc, vecType.getElementType(),
rewriter.getFloatAttr(vecType.getElementType(), 0.0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for max reduction it should be -inf?


Value setInactiveScalar = rewriter.create<ROCDL::SetInactiveOp>(
loc, vecType.getElementType(), scalarVal,
scalarInactiveValue);

Value dppResult1 = rewriter.create<amdgpu::DPPOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add comments to understand what the constants are doing here: 0xF, ...

loc, elemType, setInactiveScalar, setInactiveScalar,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_shr),
rewriter.getI32IntegerAttr(1),
rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

Value dppResult = createReducingOp(op, setInactiveScalar,
dppResult1, rewriter);

Value dppResult2 = rewriter.create<amdgpu::DPPOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add comments to understand what the constants are doing here: 0xF, ...

loc, elemType, setInactiveScalar, setInactiveScalar,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_shr),
rewriter.getI32IntegerAttr(2),
rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppResult2, rewriter);

Value dppResult3 = rewriter.create<amdgpu::DPPOp>(
loc, elemType, setInactiveScalar, setInactiveScalar,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_shr),
rewriter.getI32IntegerAttr(3),
rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppResult3, rewriter);

Value dppResult4 = rewriter.create<amdgpu::DPPOp>(
loc, elemType, dppResult, dppResult,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_shr),
rewriter.getI32IntegerAttr(4),
rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xE),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppResult4, rewriter);
Value dppResult5 = rewriter.create<amdgpu::DPPOp>(
loc, elemType, dppResult, dppResult,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_shr),
rewriter.getI32IntegerAttr(8),
rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xC),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppResult5, rewriter);

Value dppBrodcast = rewriter.create<amdgpu::DPPOp>(
loc, elemType, dppResult, dppResult,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_bcast_15),
nullptr, rewriter.getI32IntegerAttr(0xA),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppBrodcast, rewriter);

dppBrodcast = rewriter.create<amdgpu::DPPOp>(
loc, elemType, dppResult, dppResult,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::row_bcast_31),
nullptr, rewriter.getI32IntegerAttr(0xC),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

dppResult =
createReducingOp(op, dppResult, dppBrodcast, rewriter);

Value dppRotated = rewriter.create<amdgpu::DPPOp>(
loc, elemType, dppResult, dppResult,
amdgpu::DPPPermAttr::get(rewriter.getContext(),
amdgpu::DPPPerm::wave_ror),
nullptr, rewriter.getI32IntegerAttr(0xF),
rewriter.getI32IntegerAttr(0xF),
rewriter.getBoolAttr(false));

dppRotated = rewriter.create<ROCDL::StrictWWMOp>(loc, elemType,
dppRotated);
BrodcastAll = rewriter.create<ROCDL::ReadlaneOp>(
loc, elemType, dppRotated,
rewriter.create<mlir::arith::ConstantIntOp>(loc, 0, 32));
}
rewriter.create<InBoundsStoreOp>(loc, BrodcastAll,
workspaceLDSBuffer,
reductionLoop.getLowerCoords(2));
} else {
Value loadAcc = rewriter.create<InBoundsLoadOp>(
loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
zeroConstantOp);
BrodcastAll = createReducingOp(op, loadVal, loadAcc, rewriter);
rewriter.create<InBoundsStoreOp>(loc, BrodcastAll,
workspaceLDSBuffer,
reductionLoop.getLowerCoords(2));
}
}
}
Expand Down Expand Up @@ -1319,9 +1420,10 @@ void RockLowerBlockwiseGemmToThreadwisePass::runOnOperation() {
{
ConversionTarget writeAllTarget(*ctx);
writeAllTarget.addIllegalOp<BlockwiseBroadcastReduceOp, BlockwiseFillOp>();
writeAllTarget.addLegalDialect<arith::ArithDialect, rock::RockDialect,
memref::MemRefDialect, scf::SCFDialect,
vector::VectorDialect, AffineDialect>();
writeAllTarget.addLegalDialect<
arith::ArithDialect, rock::RockDialect, memref::MemRefDialect,
scf::SCFDialect, vector::VectorDialect, AffineDialect,
ROCDL::ROCDLDialect, amdgpu::AMDGPUDialect>();
writeAllTarget.addLegalOp<gpu::PrintfOp>();
RewritePatternSet writeAllPatterns(ctx);
writeAllPatterns
Expand All @@ -1335,7 +1437,8 @@ void RockLowerBlockwiseGemmToThreadwisePass::runOnOperation() {
target.addIllegalOp<FillOp, BlockwiseGemmOp, BlockwiseGemmAccelOp>();
target.addLegalDialect<arith::ArithDialect, rock::RockDialect,
affine::AffineDialect, vector::VectorDialect,
memref::MemRefDialect>();
memref::MemRefDialect, ROCDL::ROCDLDialect,
amdgpu::AMDGPUDialect>();
target.addLegalOp<gpu::PrintfOp>();

RewritePatternSet patterns(ctx);
Expand Down
Loading