Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,8 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
I32EnumAttrCase<"row_mirror", 8>,
I32EnumAttrCase<"row_half_mirror", 9>,
I32EnumAttrCase<"row_bcast_15", 10>,
I32EnumAttrCase<"row_bcast_31", 11>
I32EnumAttrCase<"row_bcast_31", 11>,
I32EnumAttrCase<"row_share", 12>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::amdgpu";
Expand Down Expand Up @@ -555,6 +556,7 @@ def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result",
- Reverse within a half-row (`row_half_mirror`)
- Broadcast the 15th lane of each row to the next row (`row_bcast`)
- Broadcast lane 31 to rows 2 and 3 (`row_bcast`)
- Broadcast a selected lane [0-15] of each row to all lanes of that same row (`row_share`)
}];
let results = (outs AnyType:$result);
let assemblyFormat = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,49 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
}];
}

// Set Inactive intrinsic.
//
// Op wrapping the `set.inactive` intrinsic. It takes two operands, `src` and
// `inactive_value`, and produces one result, `res`. The AllTypesMatch trait
// forces all three to have the same type, which is why the assembly format
// only spells out `type($src)`.
//
// NOTE(review): the positional template arguments (`[], [0]` and `1, 0, 0`)
// follow the parameter list of ROCDL_IntrOp, which is not visible in this
// hunk; confirm their meaning against that class before changing them.
def ROCDL_SetInactiveOp : ROCDL_IntrOp<"set.inactive", [], [0],
[AllTypesMatch<["res", "src", "inactive_value"]>], 1, 0, 0>,
Arguments<(ins LLVM_Type:$src, LLVM_Type:$inactive_value)> {
let results = (outs LLVM_Type:$res);
// Printed form: `<attr-dict> <src>, <inactive_value> : <type>`.
let assemblyFormat = [{
attr-dict $src `,` $inactive_value `:` type($src)
}];
let description = [{
Copies the given value while setting all inactive lanes to a specified value.
}];
}

// Strict WWM (Whole Wavefront Mode) intrinsic operation.
//
// Op wrapping the `strict.wwm` intrinsic. It takes a single operand `src` and
// produces a single result `res`; AllTypesMatch requires both to have the
// same type, so the assembly format only prints `type($src)`.
//
// NOTE(review): the positional template arguments (`[], [0]` and `1, 0, 0`)
// follow the parameter list of ROCDL_IntrOp, which is not visible in this
// hunk; confirm their meaning against that class before changing them.
def ROCDL_StrictWWMOp : ROCDL_IntrOp<"strict.wwm", [], [0],
[AllTypesMatch<["res", "src"]>], 1, 0, 0>,
Arguments<(ins LLVM_Type:$src)> {
let results = (outs LLVM_Type:$res);
// Printed form: `<attr-dict> <src> : <type>`.
let assemblyFormat = [{
attr-dict $src `:` type($src)
}];
let description = [{
Copies the active channels of the source value to the destination value,
guaranteed to be executed in Whole Wavefront Mode with all channels enabled.
}];
}

// PermLaneX16 intrinsic operation.
//
// Op wrapping the `permlanex16` intrinsic. Operands, in order: `old`, `src0`,
// `src1`, `src2`, followed by two boolean attributes `fi` and `boundControl`.
// Type constraints: `res`, `old` and `src0` share one type; `src1` and `src2`
// share another. The assembly format therefore prints two types, `type($src0)`
// and `type($src1)`, separated by `:`.
//
// NOTE(review): `[4, 5], ["fi", "boundControl"]` presumably marks operands 4
// and 5 as immediate arguments sourced from those attributes; the
// ROCDL_IntrOp parameter list is not visible in this hunk, so confirm there.
def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
[AllTypesMatch<["res", "old", "src0"]> , AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
[4, 5], ["fi", "boundControl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
I1Attr:$fi, I1Attr:$boundControl)> {
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0) `:` type($src1)
}];
let description = [{
Performs a `permlanex16` operation: a gather-style permutation of `$src0`
across rows of 16 lanes, where the lane selectors are supplied by `$src1`
and `$src2`. `$old` has the same type as the result. `$fi` (fetch-inactive)
and `$boundControl` are boolean flags of the instruction; they do not select
the permutation.
}];
}

// Fixed-length vector type of exactly two i16 elements. The BuildableType
// snippet lets generated code construct it from a builder as
// `VectorType::get({2}, i16)`.
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
ROW_HALF_MIRROR = 0x141,
BCAST15 = 0x142,
BCAST31 = 0x143,
ROW_SHARE0 = 0x150
};

auto kind = DppOp.getKind();
Expand Down Expand Up @@ -1182,6 +1183,11 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
case DPPPerm::row_bcast_31:
DppCtrl = DppCtrl::BCAST31;
break;
case DPPPerm::row_share:
// row_share carries a lane index as its permutation argument; the DPP
// control value is ROW_SHARE0 (0x150) plus that index.
// NOTE(review): this dereferences permArgument and uses cast<> rather than
// dyn_cast<>, so it relies on the argument being present and being an
// IntegerAttr -- confirm the op verifier guarantees both.
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHARE0;
}
break;
}

// Check for row_mask, bank_mask, bound_ctrl if they exist and create
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,20 @@ LogicalResult DPPOp::verify() {
}
break;
}

case DPPPerm::row_share: {
  // row_share needs an explicit lane selector.
  if (!permArgument) {
    return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                       "' value not specified");
  }
  // The lowering casts this argument to IntegerAttr unconditionally, so any
  // other attribute kind must be rejected here instead of silently accepted.
  auto intAttr = dyn_cast<IntegerAttr>(permArgument);
  if (!intAttr) {
    return emitOpError(
        "Attribute value for 'row_share' must be an integer");
  }
  // Keep the value signed and 64-bit wide: with an unsigned 32-bit local the
  // `< 0` test is always false and wide values could wrap into range.
  const int64_t attrValue = intAttr.getInt();
  if (attrValue < 0 || attrValue > 15) {
    return emitOpError(
        "Attribute value for 'row_share' must be between 0 and 15");
  }
} break;
}
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,11 @@ func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
return %0 : f16
}

// row_share with lane index 0xf: the expected DPP control value is
// ROW_SHARE0 (0x150) + 0xf = 0x15f = 351.
func.func @dpp_row_share(%arg0: i32, %arg1: i32) -> i32 {
// CHECK-LABEL: func @dpp_row_share
// CHECK: rocdl.update.dpp %arg0, %arg1 with 351, 15, 15, false : i32
// CHECK: return %0 : i32
%0 = amdgpu.dpp %arg0 %arg1 row_share ( 0xf : i32 ) : i32
return %0 : i32
}
14 changes: 13 additions & 1 deletion mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,6 @@ def Rock_ThreadwiseCopyOp :
let hasVerifier = 1;
}


def Rock_StageOp :
Rock_Op<"stage">,
Arguments<(ins DefaultValuedStrAttr<StrAttr, "stage">:$name)>
Expand All @@ -1394,4 +1393,17 @@ def Rock_StageOp :
let hasCustomAssemblyFormat = 1;
}

// `rock.wave_reduce`: reduction op with two operands (`input`, `init`), a
// `reduceMethod` attribute selecting the reduction, and a single `result`.
//
// All operand and result types are declared as AnyType and no trait ties them
// together, so the assembly format prints the input, init and result types
// explicitly.
// NOTE(review): the description speaks of reducing "a vector", but nothing in
// this definition restricts `input` to vector types -- confirm whether a
// verifier or type constraint is intended.
def Rock_WaveReductionOp : Rock_Op<"wave_reduce">,
Arguments<(ins AnyType:$input,
AnyType:$init,
ReduceMethodAttr:$reduceMethod)> {
let summary = "Wavefront-level reduction using DPP";
let description = [{
This op performs a wavefront-level reduction over a vector using DPP logic.
The reduction method is specified via the reduceMethod attribute.
}];
let results = (outs AnyType:$result);
// Printed form: `<input>, <init> <attr-dict> : <input-type>, <init-type> -> <result-type>`.
let assemblyFormat = "$input `,` $init attr-dict `:` type($input) `,` type($init) `->` type($result)";
}

#endif // ROCK_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Rock/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace rock {
#define GEN_PASS_DECL_ROCKSHUFFLEGEMMFORREDUCTIONS
#define GEN_PASS_DECL_ROCKGEMMLINALGSPLITKNORMALIZATIONPASS
#define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS
#define GEN_PASS_DECL_ROCKWAVEREDUCELOWERINGPASS

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Rock/Passes.h.inc"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,16 @@ def RockSortDimensionsMemoryLayoutPass : Pass<"rock-sort-dimensions-memory-layou
let dependentDialects = ["rock::RockDialect", "func::FuncDialect", "arith::ArithDialect", "linalg::LinalgDialect"];
}

// Function-level pass (anchored on func::FuncOp) registered under the
// command-line name `rock-wave-reduce-lowering`.
//
// It exposes one string option, `chipset`, whose default is the placeholder
// value "gfx000".
// NOTE(review): how the pass behaves when left at the "gfx000" default is not
// visible in this hunk -- confirm against the pass implementation.
def RockWaveReduceLoweringPass : Pass<"rock-wave-reduce-lowering", "::mlir::func::FuncOp"> {
let summary = "Lower rock.wave_reduce op into architecture-specific code using DPP or other intrinsics.";
let description = [{
This pass lowers the `rock.wave_reduce` operation into architecture-specific instructions.
The chipset can be specified via the `--chipset` option to control target-specific lowering.
}];
// Dialects whose ops this pass may create; they are loaded before it runs.
let dependentDialects = ["rock::RockDialect","func::FuncDialect", "arith::ArithDialect", "vector::VectorDialect",
"ROCDL::ROCDLDialect", "amdgpu::AMDGPUDialect"];
let options = [Option<"chipset", "chipset", "std::string", "\"gfx000\"",
"Target GPU architecture to generate intrinsics for (examples: gfx90a, gfx942, gfx1100, etc.).">];
}

#endif // MLIR_DIALECT_ROCK_PASSES
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Pipelines/Pipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ struct KernelOptions : public PassPipelineOptions<KernelOptions> {
PassOptions::Option<bool> tuningFallback{
*this, "tuningFallback",
desc("Falls back default if invalid config is given"), init(false)};
PassOptions::Option<std::string> chip{
*this, "chip", desc("AMDGPU ISA version: e.g. gfx908"), init("gfx000")};
};

/// Adds the `kernel` pipeline to the `OpPassManager`.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ void rock::buildKernelPipeline(OpPassManager &pm,
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
funcPm.addPass(rock::createRockVectorizeFusionsPass());
}
funcPm.addPass(rock::createRockWaveReduceLoweringPass({options.chip}));
funcPm.addPass(rock::createRockReuseLDSPass());
funcPm.addPass(rock::createRockOutputSwizzlePass());
funcPm.addPass(rock::createRockReuseLDSPass());
Expand Down
41 changes: 16 additions & 25 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,7 @@ struct BlockwiseReduceRewritePattern
// Get current workitem ID.
WorkitemIdOp tid =
rewriter.create<WorkitemIdOp>(loc, rewriter.getIndexType());
ReduceMethod rMethod = op.getReduceMethod();

// Create strides and bounds to iterate the virtual tensor
TransformMapAttr lowerTr = cast<TransformMapAttr>(
Expand Down Expand Up @@ -1098,31 +1099,21 @@ struct BlockwiseReduceRewritePattern
vectorTypeOrSelf(elemType,
std::max(rIterVectorLen, nrIterVectorLen)),
workspaceLDSBuffer, LDSLoadCoords);
Value loadAcc = rewriter.create<InBoundsLoadOp>(
loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
zeroConstantOp);
Value reduced = createReducingOp(op, loadVal, loadAcc, rewriter);
rewriter.create<InBoundsStoreOp>(loc, reduced, accReg,
zeroConstantOp);
// Storing the last reduction iter output directly to LDS[..., dr=0,
// ...]
Value rIterArg =
reductionLoop.getLowerCoords(/*domain=*/1)[rIterDim];
Value boundVal = rewriter.create<arith::ConstantIndexOp>(
loc, threadViewShape[rIterDim]);
Value strideVal =
rewriter.create<arith::ConstantIndexOp>(loc, rIterVectorLen);
Value lastIterVal =
rewriter.create<arith::SubIOp>(loc, boundVal, strideVal);
Value isLastIter = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, rIterArg, lastIterVal);
scf::IfOp ifb = rewriter.create<scf::IfOp>(
loc, isLastIter, /*withElseRegion=*/false);
{
OpBuilder thenb = ifb.getThenBodyBuilder();
thenb.create<InBoundsStoreOp>(
loc, reduced, workspaceLDSBuffer,
reductionLoop.getLowerCoords(/*domain=*/2));
Value Reduced;
if ((threadViewShape[rIterDim] / rIterVectorLen) > 1) {
auto waveReductionOp = rewriter.create<rock::WaveReductionOp>(
loc, loadVal.getType(), loadVal, initVal, rMethod);
Reduced = waveReductionOp->getResult(0);

rewriter.create<InBoundsStoreOp>(loc, Reduced, workspaceLDSBuffer,
reductionLoop.getLowerCoords(2));
} else {
Value loadAcc = rewriter.create<InBoundsLoadOp>(
loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
zeroConstantOp);
Reduced = createReducingOp(op, loadVal, loadAcc, rewriter);
rewriter.create<InBoundsStoreOp>(loc, Reduced, workspaceLDSBuffer,
reductionLoop.getLowerCoords(2));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
ShuffleGemmForReductions.cpp
GemmLinalgSplitkNormalizationPass.cpp
SortDimensionsMemoryLayout.cpp
RockWaveReduce.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Rock
Expand Down Expand Up @@ -67,4 +68,5 @@ add_rocmlir_dialect_library(MLIRRockTransforms
MLIRMHAL
MLIRMHALTransforms
MLIRCopyOpInterface
MLIRAMDGPUUtils
)
Loading
Loading