diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6fea10185402a..488f358ff3802 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
 /// it will check the operand itself and its defining op.
 LayoutAttr getLayoutAttr(const OpOperand &opr);
 
+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations.
+void removeLayoutAttrs(Operation *op);
+
 /// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
 /// it to the owner's dictionary attributes
 template <typename T,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute> newAttrs;
-  for (NamedAttribute attr : attrs) {
-    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
-      newAttrs.push_back(attr);
-  }
-  return newAttrs;
-}
-
 /// Helper function to check if the layout is packed. Layout is packed if it is
 /// 2D and lane_data[0] != 1 (data packed from col dimension).
 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -197,9 +184,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
 rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +260,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
@@ -288,9 +283,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+        rewriter, warpOp, /* new yieled values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +342,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +367,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -403,9 +398,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
-    rewriter.create<xegpu::StoreNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
-        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+    auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+    xegpu::removeLayoutAttrs(newStoreOp);
     rewriter.eraseOp(storeOp);
     return success();
   }
@@ -449,21 +444,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensures that the load op is not sinked earlier violating any
+      // barrier synchronizations.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensure that load op is not sinked earlier violating any barrier
-    // synchronizations.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +470,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
@@ -498,7 +494,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                              distributedTensorDescTy, rewriter),
-        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+        loadOp->getAttrs());
+    xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(hasPackedLayout(layout));
     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -548,12 +545,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +595,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -630,14 +626,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
           resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                                newDpasOperandExpectedTypes[i], rewriter));
     }
-    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
-        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
-        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+    auto newDpasOp =
+        rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+                                       newDpasOperands, dpasOp->getAttrs());
+    xegpu::removeLayoutAttrs(newDpasOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type.
-    newDpasOp = resolveDistributedTy(
-        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+    Value typeResolved =
+        resolveDistributedTy(newDpasOp.getResult(),
+                             distResultTypeByWarpOpOrFailure.value(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -678,13 +676,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +701,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -717,14 +715,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+        updateOp->getAttrs());
+    xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
-    newUpdateOp =
-        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    Value typeResolved = resolveDistributedTy(
+        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -758,10 +757,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -775,7 +774,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with updated tensor
     // descriptor type. Source tensor descriptor require type resolution.
     xegpu::TensorDescType newTensorDescTy =
@@ -783,9 +782,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
-    rewriter.create<xegpu::PrefetchNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
-        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.create<xegpu::PrefetchNdOp>(newWarpOp.getLoc(), TypeRange{},
+                                         newPrefetchOperands,
+                                         prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(prefetchOp);
     rewriter.eraseOp(prefetchOp);
     return success();
   }
@@ -795,17 +795,17 @@
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b85a66a8bd36..370d149ee55af 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -184,6 +184,31 @@ void xegpu::setLayoutAttrs(Operation *op,
   });
 }
 
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<xegpu::LayoutAttr>(name))
+    owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands())
+      removeLayoutAttr(opr);
+    for (OpResult result : nestOp->getOpResults())
+      removeLayoutAttr(result);
+  });
+}
+
 SmallVector<Value>
 xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                         Value value, ArrayRef<int64_t> shape) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 0bfbc4a35c03b..e78ae4a17710b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -95,10 +95,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_postop_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
 // CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
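
Illustrative usage sketch (not part of the patch): the snippet below shows how the two utilities declared in XeGPUUtils.h above could be called from other XeGPU transform code. The helper name dropLayoutHints and the surrounding setup are hypothetical; only the xegpu::removeLayoutAttr and xegpu::removeLayoutAttrs entry points come from this change.

    #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

    // Hypothetical helper: strip the temporary layout hints from a freshly
    // created SIMT-level op after distribution.
    static void dropLayoutHints(mlir::Operation *newOp) {
      // Removes the layout attribute from every operand and result of newOp,
      // recursing into any nested regions.
      mlir::xegpu::removeLayoutAttrs(newOp);
      // A single result (or operand) can also be handled individually.
      if (newOp->getNumResults() != 0)
        mlir::xegpu::removeLayoutAttr(newOp->getResult(0));
    }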