diff --git a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3acc383923ca..18a24f9c2a3d 100644
--- a/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -522,7 +522,8 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
       I32EnumAttrCase<"row_mirror", 8>,
       I32EnumAttrCase<"row_half_mirror", 9>,
       I32EnumAttrCase<"row_bcast_15", 10>,
-      I32EnumAttrCase<"row_bcast_31", 11>
+      I32EnumAttrCase<"row_bcast_31", 11>,
+      I32EnumAttrCase<"row_share", 12>
     ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::amdgpu";
@@ -555,6 +556,7 @@ def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result",
       - Reverse within a half-row (`row_half_mirror`)
       - Broadcast the 15th lane of each row to the next row (`row_bcast`)
       - Broadcast lane 31 to rows 2 and 3 (`row_bcast`)
+      - Share one lane [0-15] of each row with all lanes of that row (`row_share`)
     }];
   let results = (outs AnyType:$result);
   let assemblyFormat = [{
diff --git a/external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f194e70ee275..78cf07666ddf 100644
--- a/external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -652,6 +652,49 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
   }];
 }
 
+// Set Inactive intrinsic
+def ROCDL_SetInactiveOp : ROCDL_IntrOp<"set.inactive", [], [0],
+    [AllTypesMatch<["res", "src", "inactive_value"]>], 1, 0, 0>,
+  Arguments<(ins LLVM_Type:$src, LLVM_Type:$inactive_value)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = [{
+    attr-dict $src `,` $inactive_value `:` type($src)
+  }];
+  let description = [{
+    Copies the given value while setting all inactive lanes to a specified value.
+  }];
+}
+
+// Strict WWM intrinsic operation
+def ROCDL_StrictWWMOp : ROCDL_IntrOp<"strict.wwm", [], [0],
+    [AllTypesMatch<["res", "src"]>], 1, 0, 0>,
+  Arguments<(ins LLVM_Type:$src)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = [{
+    attr-dict $src `:` type($src)
+  }];
+  let description = [{
+    Copies the active channels of the source value to the destination value,
+    guaranteed to be executed in Whole Wavefront Mode with all channels enabled.
+  }];
+}
+
+// PermLaneX16 intrinsic operation
+def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
+    [AllTypesMatch<["res", "old", "src0"]> , AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
+    [4, 5], ["fi", "boundControl"]>,
+  Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
+             I1Attr:$fi, I1Attr:$boundControl)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = [{
+    attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0) `:` type($src1)
+  }];
+  let description = [{
+    Performs a `permlanex16` operation with the given operands, applying the
+    permutation specified by $fi to the provided inputs.
+  }];
+}
+
 def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getI16Type())">;
diff --git a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d..a7b76da0ee98 100644
--- a/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1125,6 +1125,7 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern {
       ROW_HALF_MIRROR = 0x141,
       BCAST15 = 0x142,
       BCAST31 = 0x143,
+      ROW_SHARE0 = 0x150
     };
 
     auto kind = DppOp.getKind();
@@ -1182,6 +1183,11 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern {
     case DPPPerm::row_bcast_31:
       DppCtrl = DppCtrl::BCAST31;
       break;
+    case DPPPerm::row_share:
+      if (auto intAttr = cast(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHARE0;
+      }
+      break;
     }
 
     // Check for row_mask, bank_mask, bound_ctrl if they exist and create
diff --git a/external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c8..f4e103e490aa 100644
--- a/external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -455,6 +455,21 @@ LogicalResult DPPOp::verify() {
     }
     break;
   }
+
+  case DPPPerm::row_share: {
+    if (!permArgument) {
+      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
+                         "' value not specified");
+    }
+    if (auto intAttr = dyn_cast(permArgument)) {
+      // Kept signed and 64-bit so negative or oversized values are rejected
+      // instead of wrapping into the valid range.
+      int64_t attrValue = intAttr.getInt();
+      if (attrValue < 0 || attrValue > 15) {
+        return emitOpError(
+            "Attribute value for 'row_share' must be between 0 and 15");
+      }
+    }
+  } break;
   }
   return success();
 }
diff --git a/external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir b/external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
index 14691e73e62d..64b3328b70ab 100644
--- a/external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
+++ b/external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir
@@ -137,3 +137,11 @@ func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
   %0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f16
     return %0 : f16
 }
+
+func.func @dpp_row_share(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-LABEL: func @dpp_row_share
+  // CHECK: rocdl.update.dpp %arg0, %arg1 with 351, 15, 15, false : i32
+  // CHECK: return %0 : i32
+  %0 = amdgpu.dpp %arg0 %arg1 row_share ( 0xf : i32 ) : i32
+  return %0 : i32
+}
diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td
index 3fc9b6f33d0a..dc02d15ea67f 100644
--- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td
+++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td
@@ -1379,7 +1379,6 @@ def Rock_ThreadwiseCopyOp :
   let hasVerifier = 1;
 }
 
-
 def Rock_StageOp :
     Rock_Op<"stage">,
     Arguments<(ins DefaultValuedStrAttr:$name)>
@@ -1394,4 +1393,17 @@ def Rock_StageOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+def Rock_WaveReductionOp : Rock_Op<"wave_reduce">,
+    Arguments<(ins AnyType:$input,
+                   AnyType:$init,
+                   ReduceMethodAttr:$reduceMethod)> {
+  let summary = "Wavefront-level reduction using DPP";
+  let description = [{
+    This op performs a wavefront-level reduction over a vector using DPP logic.
+    The reduction method is specified via the reduceMethod attribute.
+  }];
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$input `,` $init attr-dict `:` type($input) `,` type($init) `->` type($result)";
+}
+
 #endif // ROCK_OPS
diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h
index f8bed4406a13..d4d39d8646e3 100644
--- a/mlir/include/mlir/Dialect/Rock/Passes.h
+++ b/mlir/include/mlir/Dialect/Rock/Passes.h
@@ -47,6 +47,7 @@ namespace rock {
 #define GEN_PASS_DECL_ROCKSHUFFLEGEMMFORREDUCTIONS
 #define GEN_PASS_DECL_ROCKGEMMLINALGSPLITKNORMALIZATIONPASS
 #define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS
+#define GEN_PASS_DECL_ROCKWAVEREDUCELOWERINGPASS
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Rock/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td
index 325f329dbd87..1664b6eb4273 100644
--- a/mlir/include/mlir/Dialect/Rock/Passes.td
+++ b/mlir/include/mlir/Dialect/Rock/Passes.td
@@ -176,4 +176,16 @@ def RockSortDimensionsMemoryLayoutPass : Pass<"rock-sort-dimensions-memory-layou
   let dependentDialects = ["rock::RockDialect", "func::FuncDialect", "arith::ArithDialect", "linalg::LinalgDialect"];
 }
 
+def RockWaveReduceLoweringPass : Pass<"rock-wave-reduce-lowering", "::mlir::func::FuncOp"> {
+  let summary = "Lower rock.wave_reduce op into architecture-specific code using DPP or other intrinsics.";
+  let description = [{
+    This pass lowers the `rock.wave_reduce` operation into architecture-specific instructions.
+    The chipset can be specified via the `chipset` pass option to control target-specific lowering.
+  }];
+  let dependentDialects = ["rock::RockDialect","func::FuncDialect", "arith::ArithDialect", "vector::VectorDialect",
+                           "ROCDL::ROCDLDialect", "amdgpu::AMDGPUDialect"];
+  let options = [Option<"chipset", "chipset", "std::string", "\"gfx000\"",
+                 "Target GPU architecture to generate intrinsics for (examples: gfx90a, gfx942, gfx1100, etc.).">];
+}
+
 #endif // MLIR_DIALECT_ROCK_PASSES
diff --git a/mlir/include/mlir/Dialect/Rock/Pipelines/Pipelines.h b/mlir/include/mlir/Dialect/Rock/Pipelines/Pipelines.h
index 102f9651d528..9e5d424af4e1 100644
--- a/mlir/include/mlir/Dialect/Rock/Pipelines/Pipelines.h
+++ b/mlir/include/mlir/Dialect/Rock/Pipelines/Pipelines.h
@@ -51,6 +51,8 @@ struct KernelOptions : public PassPipelineOptions {
   PassOptions::Option tuningFallback{
       *this, "tuningFallback",
       desc("Falls back default if invalid config is given"), init(false)};
+  PassOptions::Option chip{
+      *this, "chip", desc("AMDGPU ISA version: e.g. gfx908"), init("gfx000")};
 };
 
 /// Adds the `kernel` pipeline to the `OpPassManager`.
diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
index c652b60bb24c..c1b486470321 100644
--- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
+++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
@@ -175,6 +175,7 @@ void rock::buildKernelPipeline(OpPassManager &pm,
     funcPm.addPass(createConvertLinalgToAffineLoopsPass());
     funcPm.addPass(rock::createRockVectorizeFusionsPass());
   }
+  funcPm.addPass(rock::createRockWaveReduceLoweringPass({options.chip}));
   funcPm.addPass(rock::createRockReuseLDSPass());
   funcPm.addPass(rock::createRockOutputSwizzlePass());
   funcPm.addPass(rock::createRockReuseLDSPass());
diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
index 4547b148fd0a..1b8bc7bca5e4 100644
--- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
+++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
@@ -975,6 +975,7 @@ struct BlockwiseReduceRewritePattern
     // Get current workitem ID.
     WorkitemIdOp tid =
         rewriter.create(loc, rewriter.getIndexType());
+    ReduceMethod rMethod = op.getReduceMethod();
 
     // Create strides and bounds to iterate the virtual tensor
     TransformMapAttr lowerTr = cast(
@@ -1098,31 +1099,19 @@ struct BlockwiseReduceRewritePattern
                 vectorTypeOrSelf(elemType,
                                  std::max(rIterVectorLen, nrIterVectorLen)),
                 workspaceLDSBuffer, LDSLoadCoords);
-            Value loadAcc = rewriter.create(
-                loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
-                zeroConstantOp);
-            Value reduced = createReducingOp(op, loadVal, loadAcc, rewriter);
-            rewriter.create(loc, reduced, accReg,
-                            zeroConstantOp);
-            // Storing the last reduction iter output directly to LDS[..., dr=0,
-            // ...]
-            Value rIterArg =
-                reductionLoop.getLowerCoords(/*domain=*/1)[rIterDim];
-            Value boundVal = rewriter.create(
-                loc, threadViewShape[rIterDim]);
-            Value strideVal =
-                rewriter.create(loc, rIterVectorLen);
-            Value lastIterVal =
-                rewriter.create(loc, boundVal, strideVal);
-            Value isLastIter = rewriter.create(
-                loc, arith::CmpIPredicate::eq, rIterArg, lastIterVal);
-            scf::IfOp ifb = rewriter.create(
-                loc, isLastIter, /*withElseRegion=*/false);
-            {
-              OpBuilder thenb = ifb.getThenBodyBuilder();
-              thenb.create(
-                  loc, reduced, workspaceLDSBuffer,
-                  reductionLoop.getLowerCoords(/*domain=*/2));
-            }
+            Value reduced;
+            if ((threadViewShape[rIterDim] / rIterVectorLen) > 1) {
+              auto waveReductionOp = rewriter.create(
+                  loc, loadVal.getType(), loadVal, initVal, rMethod);
+              reduced = waveReductionOp->getResult(0);
+            } else {
+              Value loadAcc = rewriter.create(
+                  loc, vectorTypeOrSelf(elemType, nrIterVectorLen), accReg,
+                  zeroConstantOp);
+              reduced = createReducingOp(op, loadVal, loadAcc, rewriter);
+            }
+            // Both paths write their result to the same LDS coordinates.
+            rewriter.create(loc, reduced, workspaceLDSBuffer,
+                            reductionLoop.getLowerCoords(/*domain=*/2));
           }
         }
diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
index 5cfd2ef02bec..1f0490ee332f 100644
--- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
   ShuffleGemmForReductions.cpp
   GemmLinalgSplitkNormalizationPass.cpp
   SortDimensionsMemoryLayout.cpp
+  RockWaveReduce.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Rock
@@ -67,4 +68,5 @@ add_rocmlir_dialect_library(MLIRRockTransforms
   MLIRMHAL
   MLIRMHALTransforms
   MLIRCopyOpInterface
+  MLIRAMDGPUUtils
 )
diff --git a/mlir/lib/Dialect/Rock/Transforms/RockWaveReduce.cpp b/mlir/lib/Dialect/Rock/Transforms/RockWaveReduce.cpp
new file mode 100644
index 000000000000..71a18641d7dc
--- /dev/null
+++ b/mlir/lib/Dialect/Rock/Transforms/RockWaveReduce.cpp
@@ -0,0 +1,200 @@
+//===- RockWaveReduce.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers the `rock.wave_reduce` operation into architecture-specific
+// intrinsics such as DPP or other wavefront-level instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Rock/IR/Rock.h"
+#include "mlir/Dialect/Rock/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace rock {
+#define GEN_PASS_DEF_ROCKWAVEREDUCELOWERINGPASS
+#include "mlir/Dialect/Rock/Passes.h.inc"
+} // namespace rock
+} // namespace mlir
+
+#define DEBUG_TYPE "rock-wave-reduce-lowering"
+
+using namespace mlir;
+using namespace mlir::rock;
+using namespace mlir::amdgpu;
+
+namespace {
+
+struct RockWaveReduceLoweringPass
+    : public rock::impl::RockWaveReduceLoweringPassBase<
+          RockWaveReduceLoweringPass> {
+  using RockWaveReduceLoweringPassBase::RockWaveReduceLoweringPassBase;
+  void runOnOperation() override;
+};
+
+/// Rewrites `rock.wave_reduce` into DPP / ROCDL wavefront intrinsics for the
+/// chipset the pattern was constructed with.
+struct RockWaveReduceRewritePattern : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  RockWaveReduceRewritePattern(MLIRContext *context, Chipset chipset)
+      : OpRewritePattern(context), chipset(chipset) {}
+  Chipset chipset;
+
+  /// Combines `input` into `acc` using the op's reduce method (sum or max);
+  /// the element type selects the integer or the non-integer variant.
+  Value createReducingOp(WaveReductionOp op, Value input, Value acc,
+                         OpBuilder &builder) const {
+    ReduceMethod rMethod = op.getReduceMethod();
+    Location loc = op.getLoc();
+    auto vecType = dyn_cast(op.getInput().getType());
+    assert(vecType && "Expected input to be a vector type");
+    Type elementType = vecType.getElementType();
+    if (rMethod == ReduceMethod::Sum) {
+      Value reduced;
+      if (elementType.isIntOrIndex()) {
+        reduced = builder.create(loc, acc, input);
+      } else {
+        reduced = builder.create(loc, acc, input);
+      }
+      return reduced;
+    } else {
+      assert(rMethod == ReduceMethod::Max);
+      Value reduced;
+      if (elementType.isIntOrIndex()) {
+        reduced = builder.create(loc, acc, input);
+      } else {
+        reduced = builder.create(loc, acc, input);
+      }
+      return reduced;
+    }
+  }
+
+  LogicalResult matchAndRewrite(WaveReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getInput();
+    Value defaultVal = op.getInit();
+
+    // Reject unsupported cases before any IR is created: a pattern that
+    // reports failure must leave the IR untouched.
+    auto vecType = dyn_cast(input.getType());
+    if (!vecType || vecType.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "expected a 1-D vector input");
+    // The replacement built below is a vector of per-element reductions, so
+    // it can only stand in for a result of the same vector type.
+    if (op.getResult().getType() != vecType)
+      return rewriter.notifyMatchFailure(
+          op, "expected the result type to match the input type");
+    // Only a gfx9 path and a gfx10+ path are implemented below.
+    if (chipset.majorVersion < 9)
+      return rewriter.notifyMatchFailure(
+          op, "wave reduction is not implemented for this chipset");
+
+    Type elemType = vecType.getElementType();
+    int64_t vecLen = vecType.getNumElements();
+    // Seeded with the input only to obtain a value of the result type; every
+    // element is overwritten by the loop below.
+    Value reducedVec = input;
+
+    // Perform intra-wavefront reduction using DPP row_shr shifts.
+    // Reduce within the row using shifts {1, 2, 3, 4, 8}.
+    // First shifts (1–3) use the input value, others use accumulated
+    // result. The result is combined using `createReducingOp`.
+    // Bank mask (0xF) enables all four banks (each 4 lanes) in the row.
+    // Row mask (0xF) enables all 4 rows in the wavefront (each row has 16
+    // lanes).
+    constexpr int rowShifts[] = {1, 2, 3, 4, 8};
+    for (int64_t i = 0; i < vecLen; ++i) {
+      Value scalarVal = rewriter.create(
+          loc, input, rewriter.create(loc, i));
+      Value setInactiveScalar =
+          rewriter.create(loc, elemType, scalarVal, defaultVal);
+      Value dppResult = setInactiveScalar;
+      Value broadcastAll;
+
+      for (int shift : rowShifts) {
+        Value shiftSrc = (shift <= 3) ? setInactiveScalar : dppResult;
+        auto dppOp = rewriter.create(
+            loc, elemType, shiftSrc, shiftSrc,
+            amdgpu::DPPPermAttr::get(rewriter.getContext(),
+                                     amdgpu::DPPPerm::row_shr),
+            rewriter.getI32IntegerAttr(shift), rewriter.getI32IntegerAttr(0xF),
+            rewriter.getI32IntegerAttr(0xF), rewriter.getBoolAttr(false));
+
+        dppResult = createReducingOp(op, dppResult, dppOp, rewriter);
+      }
+
+      // Broadcast the reduced value across the entire wavefront.
+      // Chipset version determines the broadcast method used.
+      if (chipset.majorVersion == 9) {
+        auto makeDPP = [&](amdgpu::DPPPerm perm, int rowMask, int bankMask) {
+          auto dppOp = rewriter.create(
+              loc, elemType, dppResult, dppResult,
+              amdgpu::DPPPermAttr::get(rewriter.getContext(), perm), nullptr,
+              rewriter.getI32IntegerAttr(rowMask),
+              rewriter.getI32IntegerAttr(bankMask),
+              rewriter.getBoolAttr(false));
+          dppResult = createReducingOp(op, dppResult, dppOp, rewriter);
+          return dppOp;
+        };
+        makeDPP(amdgpu::DPPPerm::row_bcast_15, 0xA, 0xF);
+        makeDPP(amdgpu::DPPPerm::row_bcast_31, 0xC, 0xF);
+        broadcastAll = makeDPP(amdgpu::DPPPerm::wave_ror, 0xF, 0xF);
+      } else {
+        // majorVersion >= 10; older chipsets were rejected above.
+        Value src1Value = rewriter.create(
+            loc, rewriter.getI32Type(),
+            rewriter.getIntegerAttr(rewriter.getI32Type(), -1));
+        Value src2Value = rewriter.create(
+            loc, rewriter.getI32Type(),
+            rewriter.getIntegerAttr(rewriter.getI32Type(), -1));
+
+        broadcastAll = rewriter.create(
+            loc, elemType, dppResult, dppResult, src1Value, src2Value,
+            rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
+      }
+
+      // Final result is reduced into lane 0 using WWM, then broadcast to all
+      // lanes.
+      Value laneReduced = rewriter.create(loc, elemType, broadcastAll);
+      laneReduced = rewriter.create(
+          loc, elemType, laneReduced, rewriter.create(loc, 0, 32));
+      reducedVec =
+          rewriter.create<vector::InsertOp>(loc, laneReduced, reducedVec, i);
+    }
+    rewriter.replaceOp(op, reducedVec);
+    return success();
+  }
+};
+
+} // namespace
+
+void RockWaveReduceLoweringPass::runOnOperation() {
+  FailureOr maybeChipset = Chipset::parse(chipset);
+  if (failed(maybeChipset)) {
+    emitError(UnknownLoc::get(&getContext()),
+              "Invalid chipset name: " + chipset);
+    return signalPassFailure();
+  }
+  RewritePatternSet patterns(&getContext());
+  patterns.add(&getContext(), *maybeChipset);
+
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
diff --git a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir
index 07c800af5a16..0b5a49b39296 100644
--- a/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir
+++ b/mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir
@@ -54,11 +54,9 @@
   // CHECK: memref.store %[[NEGINF]], %[[TO_REDUCE_ACC_MEMREF:.*]][{{.*}}]
 
   // CHECK: rock.transforming_for {{.*}} (%[[LD_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP5]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]), {{.*}}, (%[[LDS_ST_COORD:.*]]) = [#[[TMAP9]], #[[TMAP10]], #[[TMAP11]], #[[TMAP13]], #[[TMAP12]]](%[[TID0]], %[[ZERO]], %[[ZERO]]) {{.*}} bounds [1, 1, 20] strides [1, 1, 4] {
-  // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load {{.*}}[%[[LD_COORD]]]
-  // CHECK: %[[TO_REDUCE_ACC:.*]] = rock.in_bounds_load %[[TO_REDUCE_ACC_MEMREF]][%[[ZERO]]]
-  // CHECK: %[[MAX_REDUCE:.*]] = vector.reduction , %[[TO_REDUCE_VAL]] : vector<4xf32> into f32
-  // CHECK: %[[ACC_NEW:.*]] = arith.maximumf %[[TO_REDUCE_ACC]], %[[MAX_REDUCE]]
-  // CHECK: rock.in_bounds_store %[[ACC_NEW]] -> %arg2[%[[LDS_ST_COORD]]]
+  // CHECK: %[[TO_REDUCE_VAL:.*]] = rock.in_bounds_load %arg2[%[[LD_COORD]]] : memref<400xf32, #gpu.address_space>, index -> vector<4xf32>
+  // CHECK: %[[MAX_REDUCE:.*]] = rock.wave_reduce %[[TO_REDUCE_VAL]], %[[NEGINF]] {{.*}} : vector<4xf32>, f32 -> vector<4xf32>
+  // CHECK: rock.in_bounds_store %[[MAX_REDUCE]] -> %arg2[%[[LDS_ST_COORD]]] : vector<4xf32> -> memref<400xf32, #gpu.address_space>, index
   // CHECK: rock.lds_barrier
   // CHECK: rock.threadwise_read_into {{.*}}(%arg2) {{.*}} -> %arg1
 
diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir
index 981e87c8192a..f8fd170127d0 100644
--- a/mlir/test/rocmlir-driver/pipelines.mlir
+++ b/mlir/test/rocmlir-driver/pipelines.mlir
@@ -27,6 +27,7 @@
 // GPU-NEXT:canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},
 // GPU-NEXT:convert-linalg-to-affine-loops,
 // GPU-NEXT:rock-vectorize-fusions,
+// GPU-NEXT:rock-wave-reduce-lowering{chipset=gfx90a},
 // GPU-NEXT:rock-reuse-lds,
 // GPU-NEXT:rock-output-swizzle,
 // GPU-NEXT:rock-reuse-lds,
diff --git a/mlir/tools/rocmlir-driver/rocmlir-driver.cpp b/mlir/tools/rocmlir-driver/rocmlir-driver.cpp
index 3bd64df22b1c..b5539462d198 100644
--- a/mlir/tools/rocmlir-driver/rocmlir-driver.cpp
+++ b/mlir/tools/rocmlir-driver/rocmlir-driver.cpp
@@ -172,7 +172,9 @@ runKernelPipeline(StringRef arch, ModuleOp kmod, bool isHighLevel,
   }
   if (kernelPipelineSet.contains("gpu")) {
     // Set up the default lowering pipeline which goes down to GPU dialect.
-    rock::buildKernelPipeline(pm);
+    rock::KernelOptions opts;
+    opts.chip = devName.getChip().str();
+    rock::buildKernelPipeline(pm, opts);
   }
   bool isRocdlOnly = kernelPipelineSet.contains("rocdl") &&
                      !kernelPipelineSet.contains("binary");