From fd09d122269c0f53a6340e9cafc32fff95f711eb Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 19 Aug 2025 17:28:28 +0000 Subject: [PATCH 01/12] refactor --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 2 + .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 45 ++++- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 8 +- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 13 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +- .../Transforms/XeGPUWgToSgDistribute.cpp | 157 ++++++++++++++++-- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 68 ++++++++ .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 4 +- 8 files changed, 268 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h index 3592da4c46364..ce33da9632c2b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -23,6 +24,7 @@ namespace mlir { namespace xegpu { class TensorDescType; +class DistributLayoutAttrInterface; class LayoutAttr; class SliceAttr; } // namespace xegpu diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index a94987885c9e0..adfd8bae75a5a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr: let assemblyFormat = "$value"; } -def LayoutTrait: AttrInterface<"LayoutTrait"> { +def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> { let cppNamespace = "::mlir::xegpu"; let description = [{ Common trait for all XeGPU layouts. }]; let methods = [ + InterfaceMethod<"Check the availability of workgroup level layouts", + "bool", + "isWgLayout">, InterfaceMethod<"Get the rank of attribute", "int64_t", "getRank">, + InterfaceMethod<"Get the num of effective subgroups", + "int64_t", + "getNumSubgroups">, InterfaceMethod<"Get the SgLayout field of the attribute as integer array", "std::optional>", "getSgLayoutAsInt">, InterfaceMethod<"Get the SgData field of the attribute as integer array", "std::optional>", "getSgDataAsInt">, + InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData", + "xegpu::DistributLayoutAttrInterface", + "dropSgLayoutAndData">, InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional indices based on the effective subgroup layout.}], "FailureOr>", @@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> { ]; } -def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> { +def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterface]> { let summary = [{ Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor. 
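For orientation, a minimal sketch (not part of the patch) of how client code could query the new interface; it uses only the methods declared above (note the patch 01 spelling DistributLayoutAttrInterface, which a later commit in this series renames to DistributeLayoutAttrInterface), and the layout values in the comments are illustrative:

    // With #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>, a 64x128
    // tensor is distributed across 2 x 4 = 8 subgroups, each owning a 32x32
    // tile (cf. the tests added at the end of this patch).
    static void inspect(xegpu::DistributLayoutAttrInterface layout) {
      if (!layout.isWgLayout())
        return; // no sg_layout/sg_data: not a workgroup-level layout
      int64_t numSg = layout.getNumSubgroups(); // computeProduct({2, 4}) == 8
      std::optional<SmallVector<int64_t>> sgData = layout.getSgDataAsInt();
      (void)numSg; // sgData is {32, 32} for the example above
      (void)sgData;
    }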
@@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> { return 0; } + int64_t getNumSubgroups() { + std::optional> sgLayout = getSgLayoutAsInt(); + if (sgLayout.has_value()) + return computeProduct(*sgLayout); + return 0; + } + LayoutAttr dropSgLayoutAndData() { // avoid every field of the attribute is nullptr, which may lead to segment fault if (!getInstData() && !getLaneLayout()) @@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> { } -def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> { +def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]> { let summary = [{Describes the data distribution and sharing among subgroups or work-items.}]; let description = [{ @@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> { }]; let parameters = (ins - "xegpu::LayoutTrait": $parent, + "xegpu::DistributLayoutAttrInterface": $parent, "DenseI64ArrayAttr": $dims ); @@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> { return parent.isSgLayout(); } + int64_t getNumSubgroups() { + std::optional> sgLayout = getSgLayoutAsInt(); + if (sgLayout.has_value()) + return computeProduct(*sgLayout); + return 0; + } + /// Returns the SgLayout of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. std::optional> getSgLayoutAsInt() const { @@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> { return std::nullopt; } + SliceAttr dropSgLayoutAndData() { + SliceAttr attr = flatten(); + auto parent = dyn_cast(attr.getParent()); + parent = parent.dropSgLayoutAndData(); + return SliceAttr::get(getContext(), parent, attr.getDims()); + } + + SliceAttr dropInstData() { + SliceAttr attr = flatten(); + auto parent = dyn_cast(attr.getParent()); + parent = parent.dropInstData(); + return SliceAttr::get(getContext(), parent, attr.getDims()); + } + /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr /// #xegpu.slice<#xegpu.slice<#xegpu.layout, dims = [0]>, dims = [0]> /// it will coalese two slice operations and return a simplified SliceAttr diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index eb54d6887681d..deea44cd14db0 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1150,7 +1150,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$layout + OptionalAttr:$layout ); let results = (outs XeGPU_ValueType:$res); let assemblyFormat = [{ @@ -1175,7 +1175,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let builders = [ OpBuilder<(ins "Type":$res, "TypedValue": $mem_desc, - "llvm::ArrayRef": $offsets, "LayoutTrait": $layout)>, + "llvm::ArrayRef": $offsets, "DistributLayoutAttrInterface": $layout)>, ]; let extraClassDeclaration = [{ SmallVector getMixedOffsets() { @@ -1194,7 +1194,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$layout + OptionalAttr:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands)}]; @@ -1213,7 +1213,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", 
[MemoryEffects<[MemWrite]>, }]; let builders = [ OpBuilder<(ins "Value" : $data, "TypedValue": $mem_desc, - "llvm::ArrayRef": $offsets, "LayoutTrait": $layout)>, + "llvm::ArrayRef": $offsets, "DistributLayoutAttrInterface": $layout)>, ]; let extraClassDeclaration = [{ SmallVector getMixedOffsets() { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 8ea8cb1f45972..9f6e498854c18 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -290,8 +290,8 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, return affine::delinearizeIndex(builder, loc, linearId, dims); } -/// Implements LayoutTrait::getOffsets to generate instructions for -/// computing multi-dimensional offsets when distributed by LayoutAttr. +/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions +/// for computing multi-dimensional offsets when distributed by LayoutAttr. FailureOr>> LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape) { @@ -322,7 +322,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, //===----------------------------------------------------------------------===// LogicalResult SliceAttr::verify(llvm::function_ref emitError, - xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) { + xegpu::DistributLayoutAttrInterface parent, + DenseI64ArrayAttr dims) { if (!parent || !dims) return emitError() << "expected parent layout and dims attribute"; @@ -340,7 +341,7 @@ SliceAttr::verify(llvm::function_ref emitError, } SliceAttr SliceAttr::flatten() const { - xegpu::LayoutTrait parent = getParent(); + xegpu::DistributLayoutAttrInterface parent = getParent(); SmallVector slicedDims({getDims()}); while (auto sliceAttr = dyn_cast(parent)) { @@ -375,8 +376,8 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, return parent.delinearizeSubgroupId(builder, loc, linearId); } -/// Implements LayoutTrait::getOffsets to generate instructions for -/// computing multi-dimensional offsets when distributed by SliceAttr. +/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions +/// for computing multi-dimensional offsets when distributed by SliceAttr. 
FailureOr>> SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 906c71d8b8dad..05a3604ae2b43 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, TypedValue memDesc, llvm::ArrayRef offsets, - LayoutTrait layout) { + DistributLayoutAttrInterface layout) { llvm::SmallVector dynamicOffsets; llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() { void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, TypedValue memDesc, llvm::ArrayRef offsets, - LayoutTrait layout) { + DistributLayoutAttrInterface layout) { llvm::SmallVector dynamicOffsets; llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 8f1208e77ca5d..39077c1fb64b6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -55,17 +55,16 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, } static std::pair, int> -getSgShapeAndCount(ArrayRef shape, xegpu::LayoutAttr layout) { +getSgShapeAndCount(ArrayRef shape, + xegpu::DistributLayoutAttrInterface layout) { int count = 1; SmallVector sgShape(shape); - if (layout && layout.isWgLayout()) { - DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); - auto sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) - sgShape = llvm::to_vector_of(sgDataAttr.asArrayRef()); - else - sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); + SmallVector sgLayout = layout.getSgLayoutAsInt().value(); + if (auto maybeSgData = layout.getSgDataAsInt()) + sgShape = *maybeSgData; + else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) + sgShape = *maybeDerivedSgData; SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); // Clamp distUnit to the original shape to handle cases where data is // shared among subgroups, which may cause distUnit to exceed the original @@ -723,8 +722,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout // #b = #xegpu.layout -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp @@ -884,6 +883,123 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; +// A callback funtion type used to create new load/store_matrix ops +using CreatorFuncType = + llvm::function_ref baseOffsets, + SmallVector> &descOffsets)>; + +/// Utility helper for distributing logic shared by load_matrix and store_matrix +/// operations. 
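+/// (Illustrative note, not part of the original patch: the helper resolves
+/// the linear subgroup id, applies the optional sg_id_range adjustment,
+/// queries the layout for the per-subgroup offsets, and delegates the actual
+/// op creation to the callback.)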
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+LogicalResult distributeMatrixOp(
+    ConversionPatternRewriter &rewriter,
+    typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor, OpType op,
+    ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    sgId = rewriter.create<index::SubOp>(loc, startOfRangeVal, sgId);
+  }
+
+  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeMdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeMdescOffsets);
+  return success();
+}
+
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  return llvm::map_to_vector(
+      llvm::zip_equal(lhs, rhs), [&](auto p) -> OpFoldResult {
+        auto l = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<0>(p));
+        auto r = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<1>(p));
+        return rewriter.create<index::AddOp>(loc, l, r).getResult();
+      });
+}
+
+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getRes().getType();
+    ArrayRef<int64_t> wgShape = valueTy.getShape();
+    Type elemTy = valueTy.getElementType();
+
+    // the callback function for creating new LoadMatrixOps;
+    // baseOffsets are the original offsets of the op, and
+    // descOffsets are the offsets relative to the mem_desc accessed
+    // by each subgroup op.
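+    // (Illustrative note, not part of the original patch: descOffsets may
+    // hold several entries per subgroup. E.g., a 128x128 shape with
+    // sg_layout = [2, 4] and sg_data = [32, 32] has 16 tiles for 8
+    // subgroups, so each subgroup materializes two ops, round-robin style.)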
+ auto callback = [&](ArrayRef baseOffsets, + SmallVector> descOffsets) { + auto layout = op.getLayoutAttr(); + for (auto [v, descOffsets] : llvm::zip(adaptor.getData(), descOffsets)) { + SmallVector sgOffsets = + add(rewriter, loc, baseOffsets, getAsOpFoldResult(descOffsets)); + rewriter.create( + loc, v, op.getMemDesc(), sgOffsets, layout.dropSgLayoutAndData()); + } + rewriter.eraseOp(op); + }; + return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback); + } +}; + } // namespace namespace mlir { @@ -895,7 +1011,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, - WgToSgArithConstantOp>(patterns.getContext()); + WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -985,7 +1102,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return xegpu::TensorDescType(); }; - auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { + auto isLegal = [&](xegpu::DistributLayoutAttrInterface layout) -> bool { return !layout || !layout.isWgLayout(); }; @@ -1002,9 +1119,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp( - [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getLayoutAttr(op.getResult())); + target.addDynamicallyLegalOp( + [=](xegpu::LoadMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); + }); + + target.addDynamicallyLegalOp( + [=](xegpu::StoreMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); }); target.addDynamicallyLegalOp( @@ -1015,6 +1137,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp( + [=](vector::BroadcastOp op) -> bool { + return isLegal(xegpu::getLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index f4a49da71605f..5f851e9003a0e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -372,4 +372,72 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<256x128xf32> gpu.return } + + // CHECK-LABEL: distribute_load_matrix + // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> + gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) { + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[c4:%.+]] = arith.constant 4 : index + //CHECK: [[c4_0:%.+]] = arith.constant 4 : index + //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] + //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]] + //CHECK: [[c32_1:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[c0_1:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y_0:%.+]] = 
arith.addi [[l_off_y]], [[c0]] : index + //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index + //CHECK: [[c64:%.+]] = arith.constant 64 : index + //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y_0]], [[c64]] + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]] + //CHECK: [[c0_2:%.+]] = arith.constant 0 : index + //CHECK: [[off_y:%.+]] = index.add [[c0_2]], [[mod_y]] + //CHECK: [[c0_3:%.+]] = arith.constant 0 : index + //CHECK: [[off_x:%.+]] = index.add [[c0_3]], [[mod_x]] + //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> + %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> + gpu.return + } + + //CHECK-LABEL: distribute_store_matrix + //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> + gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) { + //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[c4:%.+]] = arith.constant 4 : index + //CHECK: [[c4_0:%.+]] = arith.constant 4 : index + //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] + //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]] + //CHECK: [[c32_1:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[c0_2:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index + //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index + //CHECK: [[c64:%.+]] = arith.constant 64 : index + //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: [[c0_3:%.+]] = arith.constant 0 : index + //CHECK: [[off_y:%.+]] = index.add [[c0_3]], [[mod_y]] + //CHECK: [[c0_4:%.+]] = arith.constant 0 : index + //CHECK: [[off_x:%.+]] = index.add [[c0_4]], [[mod_x]] + //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index + %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<64x128xf32> + %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> + + gpu.return + } + } diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 58962714b7864..d94d285b1105d 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -156,8 +156,8 @@ struct TestXeGPUUnrollingPatterns #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") // Test pattern for distributing vector::StepOp from workgroup to subgroup. -// Validates LayoutTrait interfaces for offset computation abstraction between -// LayoutAttr and SliceAttr. 
+// Validates DistributLayoutAttrInterface interfaces for offset computation +// abstraction between LayoutAttr and SliceAttr. class TestStepOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; From 6fc6ec7b69d6d7d1e86ec05e876f80876fe46e6c Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 19 Aug 2025 18:35:43 +0000 Subject: [PATCH 02/12] refactor createNdOp --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 2 +- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 10 +- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 12 +- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 14 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +- .../Transforms/XeGPUWgToSgDistribute.cpp | 209 ++++++++---------- .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 8 +- mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 16 +- .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +- 9 files changed, 134 insertions(+), 143 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h index ce33da9632c2b..1d152f0c9ca9a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h @@ -24,7 +24,7 @@ namespace mlir { namespace xegpu { class TensorDescType; -class DistributLayoutAttrInterface; +class DistributeLayoutAttrInterface; class LayoutAttr; class SliceAttr; } // namespace xegpu diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index adfd8bae75a5a..de86141ad006a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -175,7 +175,7 @@ def XeGPU_FenceScopeAttr: let assemblyFormat = "$value"; } -def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> { +def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> { let cppNamespace = "::mlir::xegpu"; let description = [{ Common trait for all XeGPU layouts. @@ -198,7 +198,7 @@ def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> "std::optional>", "getSgDataAsInt">, InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData", - "xegpu::DistributLayoutAttrInterface", + "xegpu::DistributeLayoutAttrInterface", "dropSgLayoutAndData">, InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional indices based on the effective subgroup layout.}], @@ -215,7 +215,7 @@ def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> ]; } -def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterface]> { +def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> { let summary = [{ Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor. 
@@ -409,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterfa
 }
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -436,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]
   }];
 
   let parameters = (ins
-    "xegpu::DistributLayoutAttrInterface": $parent,
+    "xegpu::DistributeLayoutAttrInterface": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index deea44cd14db0..3ba9eaa4a66da 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<int>(MemorySpace::Global);
     }
 
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
+    }
+
   }];
 }
 
@@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
                        Variadic<Index>: $offsets,
                        DenseI64ArrayAttr: $const_offsets,
-                       OptionalAttr<DistributLayoutAttrInterface>:$layout
+                       OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
+      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
                        XeGPU_MemDesc:$mem_desc,
                        Variadic<Index>: $offsets,
                        DenseI64ArrayAttr: $const_offsets,
-                       OptionalAttr<DistributLayoutAttrInterface>:$layout
+                       OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
+      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9f6e498854c18..de118b7faea4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
-/// for computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
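+/// (Illustrative note, not part of the upstream change: for a 2-D layout
+/// with sg_layout = [L0, L1], a linear id i is delinearized to
+/// (i / L1, i % L1); each index is then scaled by sg_data and wrapped to the
+/// tensor shape via index.remu, matching the affine maps in the tests below.)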
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
@@ -322,7 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::DistributLayoutAttrInterface parent,
+                  xegpu::DistributeLayoutAttrInterface parent,
                   DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
@@ -341,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::DistributLayoutAttrInterface parent = getParent();
+  xegpu::DistributeLayoutAttrInterface parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<SliceAttr>(parent)) {
@@ -376,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
-/// for computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 05a3604ae2b43..0e22af900daf1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         DistributLayoutAttrInterface layout) {
+                         DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          DistributLayoutAttrInterface layout) {
+                          DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 39077c1fb64b6..ca1209e776d0e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -56,7 +56,7 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape,
-                   xegpu::DistributLayoutAttrInterface layout) {
+                   xegpu::DistributeLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
   if (layout && layout.isWgLayout()) {
@@ -76,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
   return std::make_pair(sgShape, count);
 }
 
+// A utility helper to generate elementwise addition ops for index computation.
+// lhs and rhs are vectors of Values. If the ranks of lhs and rhs don't match,
+// left-alignment is performed, e.g., lhs = [%a, %b, %c] and rhs = [%x, %y]
+// yield [%a, %b + %x, %c + %y].
+static SmallVector add(ConversionPatternRewriter &rewriter, + Location loc, ArrayRef lhs, + ArrayRef rhs) { + SmallVector reversedResult; + auto l = lhs.rbegin(); + auto r = rhs.rbegin(); + for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) { + if (l == lhs.rend()) { + reversedResult.push_back(*r); + } else if (r == rhs.rend()) { + reversedResult.push_back(*l); + } else { + auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l); + auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r); + auto add = rewriter.createOrFold(loc, lval, rval); + reversedResult.push_back(add); + } + } + return llvm::to_vector(llvm::reverse(reversedResult)); +} + +// A callback funtion type used to create new load/store_matrix ops +using CreatorFuncType = + llvm::function_ref baseOffsets, + SmallVector> &descOffsets)>; + +/// Utility helper for distributing logic shared by operations with offsets +template ::value>> +static LogicalResult +distributeOp(ConversionPatternRewriter &rewriter, + typename OpConversionPattern::OneToNOpAdaptor adaptor, + OpType op, ArrayRef wgShape, CreatorFuncType callback) { + Location loc = op.getLoc(); + auto layout = op.getLayoutAttr(); + if (!layout || !layout.isWgLayout()) + return failure(); + + Value sgId = rewriter.create(loc, /*upper_bound=*/nullptr); + + // adjust the linearId if the range specifier is present + int64_t startOfRange = -1, endOfRange = -1; + bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); + if (sgIdRangeSpecified) { + if (layout.getNumSubgroups() != endOfRange - startOfRange) + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + Value startOfRangeVal = + rewriter.create(loc, startOfRange); + sgId = rewriter.create(loc, sgId, startOfRangeVal); + } + + auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(maybeMdescOffsets)) + return failure(); + + SmallVector wgOffsets = op.getMixedOffsets(); + callback(wgOffsets, *maybeMdescOffsets); + return success(); +} + /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. @@ -136,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); - auto layout = dyn_cast(tdescTy.getLayout()); - if (!layout) - return failure(); - Type elemTy = tdescTy.getElementType(); ArrayRef wgShape = tdescTy.getShape(); - // sgLayout must be present for workgroup-level distribution. 
- SmallVector sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - else - return rewriter.notifyMatchFailure( - op, "sgLayout attribute is required in layout"); - - // Get the subgroup ID - Value linearSgId = - gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = - isSgIdRangeSpecified(op, startOfRange, endOfRange); - - if (sgIdRangeSpecified) { - int64_t sgCount = endOfRange - startOfRange; - if (computeProduct(sgLayout) != sgCount) - return rewriter.notifyMatchFailure( - op, "sg_layout size must match the sg_id_range"); - // Subtract startOfRange from the original subgroup id to get - // the adjusted sg id - Value startOfRangeVal = - arith::ConstantIndexOp::create(rewriter, loc, startOfRange); - linearSgId = - rewriter.createOrFold(loc, linearSgId, startOfRangeVal); - } + Type elemTy = tdescTy.getElementType(); - auto maybeTdescOffsets = - layout.getOffsets(rewriter, loc, linearSgId, wgShape); - if (failed(maybeTdescOffsets)) - return failure(); + // the call back function for creating new CreateNdOps, + // the baseOffsets is the origial offsets of the op, and + // descOffsets is the relative offsets to the mem_desc accessed + // by each subgroup op. + auto callback = [&](ArrayRef baseOffsets, + SmallVector> descOffsets) { + xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + auto newTdescTy = xegpu::TensorDescType::get( + ctx, sgShape, elemTy, tdescTy.getEncoding(), + layout.dropSgLayoutAndData()); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; - xegpu::TensorDescType newTdescTy = - xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), - layout.dropSgLayoutAndData()); + SmallVector newOps; + for (auto offsets : descOffsets) { + SmallVector sgOffsets = + add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets)); + auto newOp = xegpu::CreateNdDescOp::create( + rewriter, loc, newTdescTy, op.getSource(), sgOffsets, + op.getMixedSizes(), op.getMixedStrides()); - SmallVector newCreateNdOps; - SmallVector origOffsets = op.getMixedOffsets(); - - for (auto tdescOffsets : *maybeTdescOffsets) { - SmallVector sgOffsets; - size_t rank = tdescOffsets.size(); - for (size_t i = 0; i < rank; i++) { - size_t idx = origOffsets.size() - rank + i; - Value add = rewriter.createOrFold( - loc, tdescOffsets[i], - getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx])); - sgOffsets.push_back(add); + newOps.push_back(newOp); } + rewriter.replaceOpWithMultiple(op, {newOps}); + }; - auto newOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), sgOffsets, - op.getMixedSizes(), op.getMixedStrides()); - newCreateNdOps.push_back(newOp); - } - rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); - return success(); + return distributeOp(rewriter, adaptor, op, wgShape, callback); } }; @@ -883,59 +913,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; -// A callback funtion type used to create new load/store_matrix ops -using CreatorFuncType = - llvm::function_ref baseOffsets, - SmallVector> &descOffsets)>; - -/// Utility helper for distributing logic shared by load_matrix and store_matrix -/// operations. 
-template ::value>> -LogicalResult distributeMatrixOp( - ConversionPatternRewriter &rewriter, - typename OpConversionPattern::OneToNOpAdaptor adaptor, OpType op, - ArrayRef wgShape, CreatorFuncType callback) { - Location loc = op.getLoc(); - auto layout = op.getLayoutAttr(); - if (!layout || !layout.isWgLayout()) - return failure(); - - Value sgId = rewriter.create(loc, /*upper_bound=*/nullptr); - - // adjust the linearId if the range specifier is present - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); - if (sgIdRangeSpecified) { - if (layout.getNumSubgroups() != endOfRange - startOfRange) - return rewriter.notifyMatchFailure( - op, "sg_layout size must match the sg_id_range"); - Value startOfRangeVal = - rewriter.create(loc, startOfRange); - sgId = rewriter.create(loc, startOfRangeVal, sgId); - } - - auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); - if (failed(maybeMdescOffsets)) - return failure(); - - SmallVector wgOffsets = op.getMixedOffsets(); - callback(wgOffsets, *maybeMdescOffsets); - return success(); -} - -static SmallVector add(ConversionPatternRewriter &rewriter, - Location loc, ArrayRef lhs, - ArrayRef rhs) { - return llvm::map_to_vector( - llvm::zip_equal(lhs, rhs), [&](auto p) -> OpFoldResult { - auto l = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<0>(p)); - auto r = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<1>(p)); - return rewriter.create(loc, l, r).getResult(); - }); -} - struct WgToSgLoadMatrixOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -968,7 +945,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { rewriter.replaceOpWithMultiple(op, {newOps}); }; - return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback); + return distributeOp(rewriter, adaptor, op, wgShape, callback); } }; @@ -996,7 +973,7 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern { } rewriter.eraseOp(op); }; - return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback); + return distributeOp(rewriter, adaptor, op, wgShape, callback); } }; @@ -1102,7 +1079,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return xegpu::TensorDescType(); }; - auto isLegal = [&](xegpu::DistributLayoutAttrInterface layout) -> bool { + auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool { return !layout || !layout.isWgLayout(); }; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index e5cc65e6bd3d7..7dcdcca070ac9 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -30,9 +30,13 @@ gpu.module @test_round_robin_assignment { //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]] + //CHECK: [[modY:%.+]] = index.remu [[ADDY]], [[C128]] //CHECK: [[C64_2:%.+]] = arith.constant 64 : index - //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]] + //CHECK: [[modX:%.+]] = index.remu [[ADDX]], [[C64_2]] + //CHECK: [[C0_3:%.+]] = arith.constant 0 : index + //CHECK: [[offX:%.+]] = index.add [[modX]], [[C0_3]] + //CHECK: [[C0_4:%.+]] = arith.constant 0 : index + //CHECK: [[offY:%.+]] = index.add [[modY]], [[C0_4]] //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> 
!xegpu.tensor_desc<16x64xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 5f851e9003a0e..bdda77a69f22e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -17,9 +17,13 @@ gpu.module @test_1_1_assignment { //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index //CHECK: [[C256:%.+]] = arith.constant 256 : index - //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]] + //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]] //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]] + //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]] + //CHECK: [[C0_3:%.+]] = arith.constant 0 : index + //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_3]] + //CHECK: [[C0_4:%.+]] = arith.constant 0 : index + //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_4]] //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> @@ -396,9 +400,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { //CHECK: [[c128:%.+]] = arith.constant 128 : index //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]] //CHECK: [[c0_2:%.+]] = arith.constant 0 : index - //CHECK: [[off_y:%.+]] = index.add [[c0_2]], [[mod_y]] + //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_2]] //CHECK: [[c0_3:%.+]] = arith.constant 0 : index - //CHECK: [[off_x:%.+]] = index.add [[c0_3]], [[mod_x]] + //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_3]] //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> @@ -429,9 +433,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { //CHECK: [[c128:%.+]] = arith.constant 128 : index //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]] //CHECK: [[c0_3:%.+]] = arith.constant 0 : index - //CHECK: [[off_y:%.+]] = index.add [[c0_3]], [[mod_y]] + //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_3]] //CHECK: [[c0_4:%.+]] = arith.constant 0 : index - //CHECK: [[off_x:%.+]] = index.add [[c0_4]], [[mod_x]] + //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_4]] //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<64x128xf32> %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index d94d285b1105d..8d2fb85655c72 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -156,7 +156,7 @@ struct TestXeGPUUnrollingPatterns #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") // Test pattern for distributing vector::StepOp from workgroup to subgroup. 
-// Validates DistributLayoutAttrInterface interfaces for offset computation +// Validates DistributeLayoutAttrInterface interfaces for offset computation // abstraction between LayoutAttr and SliceAttr. class TestStepOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; From 9af1f7f417c5c4ee0e1a47830d8f061286cfe9e0 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 19 Aug 2025 20:20:07 +0000 Subject: [PATCH 03/12] fix typo --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index ca1209e776d0e..d31f0f1a75c51 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -752,8 +752,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout // #b = #xegpu.layout -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp From 93acad2807a0fd9ea9e3d5f594afc681dfb50840 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 16:49:05 +0000 Subject: [PATCH 04/12] cleanup --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 66 +++ .../Transforms/XeGPUWgToSgDistribute.cpp | 384 ++++++------------ .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 8 +- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 58 +++ mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 76 +--- 5 files changed, 249 insertions(+), 343 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 3ba9eaa4a66da..3182552288ca6 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -236,6 +236,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface return dyn_cast_if_present(getType().getLayout()); } + ArrayRef getDistributeShape() { + return getTensorDescShape(); + } + }]; } @@ -266,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { xegpu::TensorDescType getTensorDescType() { return getTensorDesc().getType(); } + + SmallVector getMixedOffsets() { + auto statics = getConstOffsets().value_or(SmallVector()); + auto dynamics = getOffsets(); + if (statics.size() == 0 && dynamics.size() == 0) + return {}; + return getMixedValues(statics, dynamics, getContext()); + } + + xegpu::DistributeLayoutAttrInterface getLayoutAttr() { + return dyn_cast_if_present(getTensorDescType().getLayout()); + } + + ArrayRef getDistributeShape() { + return getTensorDescType().getShape(); + } + }]; let assemblyFormat = [{ @@ -347,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ xegpu::TensorDescType getTensorDescType() { return getTensorDesc().getType(); } + + SmallVector getMixedOffsets() { + auto statics = getConstOffsets().value_or(SmallVector()); + auto dynamics = getOffsets(); + if (statics.size() == 0 && dynamics.size() == 0) + return {}; + return getMixedValues(statics, dynamics, 
getContext());
+    }
+
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];
 
   let assemblyFormat = [{
@@ -421,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];
 
   let assemblyFormat = [{
@@ -644,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     xegpu::TensorDescType getTensorDescType() {
       return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
+
   }];
 
   let assemblyFormat = [{
@@ -1185,6 +1242,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getRes().getType().getShape();
+    }
   }];
 
   let hasVerifier = 1;
@@ -1223,6 +1284,11 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getData().getType().getShape();
+    }
+
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d31f0f1a75c51..76bda64ffac1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -79,43 +79,42 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
 // A utility helper to generate elementwise addition ops for index computation.
 // lhs and rhs are vectors of Values. If the ranks of lhs and rhs don't match,
 // left-alignment is performed, e.g., lhs = [%a, %b, %c] and rhs = [%x, %y]
 // yield [%a, %b + %x, %c + %y].
-static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
-                                     Location loc, ArrayRef<OpFoldResult> lhs,
-                                     ArrayRef<OpFoldResult> rhs) {
-  SmallVector<OpFoldResult> reversedResult;
-  auto l = lhs.rbegin();
-  auto r = rhs.rbegin();
-  for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) {
-    if (l == lhs.rend()) {
-      reversedResult.push_back(*r);
-    } else if (r == rhs.rend()) {
-      reversedResult.push_back(*l);
-    } else {
-      auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l);
-      auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r);
-      auto add = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
-      reversedResult.push_back(add);
-    }
-  }
-  return llvm::to_vector(llvm::reverse(reversedResult));
-}
+static SmallVector<OpFoldResult>
+genIndexAdds(ConversionPatternRewriter &rewriter, Location loc,
+             ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+  // Ensure a is the longer of the two.
+  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? 
rhs : lhs; + SmallVector results(a.take_front(a.size() - b.size())); + a = a.slice(a.size() - b.size()); + for (auto [l, r] : llvm::zip(a, b)) { + auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, l); + auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, r); + results.push_back(rewriter.createOrFold(loc, lval, rval)); } - return llvm::to_vector(llvm::reverse(reversedResult)); + return results; } -// A callback funtion type used to create new load/store_matrix ops -using CreatorFuncType = - llvm::function_ref baseOffsets, - SmallVector> &descOffsets)>; - -/// Utility helper for distributing logic shared by operations with offsets -template ::value>> +/// Utility helper for deriving a list of offsets for each sub-TensorDescs +/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the +/// associated distribute layout attribute, the shape, subgroup id and the +/// original offsets of the op +template < + typename OpType, + typename = std::enable_if_t::value>> static LogicalResult -distributeOp(ConversionPatternRewriter &rewriter, - typename OpConversionPattern::OneToNOpAdaptor adaptor, - OpType op, ArrayRef wgShape, CreatorFuncType callback) { +genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, + SmallVector> &offsetsList) { Location loc = op.getLoc(); - auto layout = op.getLayoutAttr(); + SmallVector origOffsets = op.getMixedOffsets(); + // not applicable to ops without offsets operands. + if (origOffsets.empty()) + return failure(); + + // not applicable to ops without workgroup layout attributes + xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); if (!layout || !layout.isWgLayout()) return failure(); @@ -133,12 +132,23 @@ distributeOp(ConversionPatternRewriter &rewriter, sgId = rewriter.create(loc, sgId, startOfRangeVal); } - auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); - if (failed(maybeMdescOffsets)) + // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory + // descriptors to be accessed, based on the layout information. + ArrayRef wgShape = op.getDistributeShape(); + auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(maybeDescOffsets)) return failure(); - SmallVector wgOffsets = op.getMixedOffsets(); - callback(wgOffsets, *maybeMdescOffsets); + // Compute the final global offsets for each accessed sub-tensor + // or sub-memory descriptor. + // SmallVector> offsetsList; + for (const auto &sgOffsets : *maybeDescOffsets) { + SmallVector newOffsets = + genIndexAdds(rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets); + offsetsList.push_back(std::move(newOffsets)); + } + + // callback(offsetsList); return success(); } @@ -193,44 +203,31 @@ struct WgToSgCreateNdOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - // Ensure that the op has explicit offsets specified (either dynamic or - // constant). 
- if (op.getMixedOffsets().empty()) + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); - Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); ArrayRef wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); + xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + auto newTdescTy = + xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), + layout.dropSgLayoutAndData()); - // the call back function for creating new CreateNdOps, - // the baseOffsets is the origial offsets of the op, and - // descOffsets is the relative offsets to the mem_desc accessed - // by each subgroup op. - auto callback = [&](ArrayRef baseOffsets, - SmallVector> descOffsets) { - xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; - auto newTdescTy = xegpu::TensorDescType::get( - ctx, sgShape, elemTy, tdescTy.getEncoding(), - layout.dropSgLayoutAndData()); - - SmallVector newOps; - for (auto offsets : descOffsets) { - SmallVector sgOffsets = - add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets)); - auto newOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), sgOffsets, - op.getMixedSizes(), op.getMixedStrides()); + SmallVector newOps; + for (auto offsets : offsetsList) { + auto newOp = xegpu::CreateNdDescOp::create( + rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets, + op.getMixedSizes(), op.getMixedStrides()); - newOps.push_back(newOp); - } - rewriter.replaceOpWithMultiple(op, {newOps}); - }; + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); - return distributeOp(rewriter, adaptor, op, wgShape, callback); + return success(); } }; @@ -283,12 +280,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector newLoadOps; - - int64_t offsetSize = static_cast(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); + SmallVector newLoadOps; for (auto src : adaptor.getTensorDesc()) { xegpu::TensorDescType tdescTy = dyn_cast(src.getType()); @@ -311,9 +306,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - int64_t offsetSize = static_cast(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) @@ -325,100 +318,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern { } }; -// Utility function to compute global offsets for subgroup operations. -// Returns a vector of new offsets for each subgroup, given the original op's -// offsets and subgroup relative offsets. 
-static SmallVector> -computeOffsets(Operation *op, ArrayRef> sgOffsetsList, - ArrayRef origOffsets, - ConversionPatternRewriter &rewriter) { - SmallVector> finalOffsets; - Location loc = op->getLoc(); - for (const auto &sgOffsets : sgOffsetsList) { - SmallVector newOffsets; - size_t rank = sgOffsets.size(); - for (size_t i = 0; i < rank; i++) { - size_t idx = origOffsets.size() - rank + i; - Value add = rewriter.createOrFold( - loc, sgOffsets[i], - getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx])); - newOffsets.push_back(add); - } - finalOffsets.push_back(std::move(newOffsets)); - } - return finalOffsets; -} - -// Utility function to get sgShape, sgOffsetList for a given -// op. -template -LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor, - ConversionPatternRewriter &rewriter, - SmallVector &sgShape, - SmallVector> &sgOffsetList) { - int64_t offsetSize = static_cast(op.getOffsets().size()); - if (offsetSize == 0 && (!op.getConstOffsetsAttr())) - return failure(); - - Location loc = op.getLoc(); - Value tdesc = op.getTensorDesc(); - auto tdescTy = dyn_cast(tdesc.getType()); - if (!tdescTy) - return failure(); - auto layout = dyn_cast(tdescTy.getLayout()); - if (!layout) - return failure(); - - SmallVector sgLayout; - auto sgLayoutAttr = layout.getSgLayout(); - if (!sgLayoutAttr) - return rewriter.notifyMatchFailure( - op, "sgLayout attribute is required in layout"); - sgLayout = llvm::to_vector_of(sgLayoutAttr.asArrayRef()); - - ArrayRef wgShape = tdescTy.getShape(); - int count; - std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); - - // Get the subgroup ID - Value linearSgId = - gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); - - if (sgIdRangeSpecified) { - int64_t sgCount = endOfRange - startOfRange; - if (computeProduct(sgLayout) != sgCount) - return rewriter.notifyMatchFailure( - op, "sg_layout size must match the sg_id_range"); - Value startOfRangeVal = - rewriter.create(loc, startOfRange); - linearSgId = - rewriter.createOrFold(loc, linearSgId, startOfRangeVal); - } - - auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape); - if (failed(sgOffsets)) - return failure(); - - sgOffsetList = *sgOffsets; - return success(); -} - -template -SmallVector getOffsets(OpTy op, - ConversionPatternRewriter &rewriter) { - SmallVector origOffsets; - if (auto constOffsets = op.getConstOffsetsAttr()) { - for (auto attr : constOffsets.asArrayRef()) - origOffsets.push_back(rewriter.getIndexAttr(attr)); - } - for (auto v : op.getOffsets()) - origOffsets.push_back(v); - return origOffsets; -} - // This pattern transforms the LoadNdOp with explicit offsets to load // subgroup data. 
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern { @@ -427,33 +326,24 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern { matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector sgShape; - SmallVector> sgOffsetList; - - // Do the distribution from workgroup to subgroup and get subgroup offsets - if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); - // Get the original workgroup offsets - SmallVector origOffsets = getOffsets(op, rewriter); - - // Calculate the final offsets for each subgroup - auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); - - SmallVector newLoadOps; - for (auto [offsets, tdesc] : - llvm::zip(finalOffsets, adaptor.getTensorDesc())) { - VectorType newResTy = VectorType::get( - sgShape, - dyn_cast(tdesc.getType()).getElementType()); - auto newLoadOp = rewriter.create( - op.getLoc(), newResTy, tdesc, offsets, - /*packed=*/nullptr, - /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); - newLoadOps.push_back(newLoadOp); + SmallVector newOps; + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { + auto tdescTy = dyn_cast(tdesc.getType()); + VectorType newResTy = + VectorType::get(tdescTy.getShape(), tdescTy.getElementType()); + auto newOp = xegpu::LoadNdOp::create( + rewriter, op.getLoc(), newResTy, tdesc, offsets, + /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + newOps.push_back(newOp); } - rewriter.replaceOpWithMultiple(op, {newLoadOps}); + rewriter.replaceOpWithMultiple(op, {newOps}); + return success(); } }; @@ -466,27 +356,18 @@ struct WgToSgStoreNdOpWithOffset LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - SmallVector sgShape; - SmallVector> sgOffsetList; - - // Do the distribution from workgroup to subgroup and get subgroup offsets - if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); - // Get the original workgroup offsets - SmallVector origOffsets = getOffsets(op, rewriter); - - // Calculate the final offsets for each subgroup - auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); - - for (auto [offsets, tdesc, value] : - llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) { - rewriter.create(op.getLoc(), value, tdesc, offsets, + for (auto [v, tdesc, offsets] : + llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { + rewriter.create(op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); } rewriter.eraseOp(op); + return success(); } }; @@ -499,27 +380,18 @@ struct WgToSgPrefetchNdOpWithOffset LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - SmallVector sgShape; - SmallVector> sgOffsetList; - - // Do the distribution from workgroup to subgroup and get subgroup offsets - if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); - // Get the original workgroup offsets - SmallVector origOffsets = 
getOffsets(op, rewriter); - - // Calculate the final offsets for each subgroup - auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter); - - for (auto [offsets, tdesc] : - llvm::zip(finalOffsets, adaptor.getTensorDesc())) { + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { rewriter.create( op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); } rewriter.eraseOp(op); + return success(); } }; @@ -918,34 +790,28 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); + + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + ArrayRef wgShape = op.getDistributeShape(); VectorType valueTy = op.getRes().getType(); - ArrayRef wgShape = valueTy.getShape(); Type elemTy = valueTy.getElementType(); - // the call back function for creating new LoadMatrixOps, - // the baseOffsets is the origial offsets of the op, and - // descOffsets is the relative offsets to the mem_desc accessed - // by each subgroup op. - auto callback = [&](ArrayRef baseOffsets, - SmallVector> descOffsets) { - auto layout = op.getLayoutAttr(); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; - VectorType newResTy = VectorType::get(sgShape, elemTy); - - SmallVector newOps; - for (auto offsets : descOffsets) { - SmallVector sgOffsets = - add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets)); - auto newOp = rewriter.create( - loc, newResTy, op.getMemDesc(), sgOffsets, - layout.dropSgLayoutAndData()); - newOps.push_back(newOp); - } - rewriter.replaceOpWithMultiple(op, {newOps}); - }; + xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResTy = VectorType::get(sgShape, elemTy); + SmallVector newOps; + for (auto offsets : offsetsList) { + auto newOp = rewriter.create( + op.getLoc(), newResTy, op.getMemDesc(), offsets, + layout.dropSgLayoutAndData()); + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); - return distributeOp(rewriter, adaptor, op, wgShape, callback); + return success(); } }; @@ -954,26 +820,18 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); - ArrayRef wgShape = valueTy.getShape(); - - // the call back function for creating new StoreMatrixOps, - // the baseOffsets is the origial offsets of the op, and - // descOffsets is the relative offsets to the mem_desc accessed - // by each subgroup op. 
- auto callback = [&](ArrayRef baseOffsets, - SmallVector> descOffsets) { - auto layout = op.getLayoutAttr(); - for (auto [v, descOffsets] : llvm::zip(adaptor.getData(), descOffsets)) { - SmallVector sgOffsets = - add(rewriter, loc, baseOffsets, getAsOpFoldResult(descOffsets)); - rewriter.create( - loc, v, op.getMemDesc(), sgOffsets, layout.dropSgLayoutAndData()); - } - rewriter.eraseOp(op); - }; - return distributeOp(rewriter, adaptor, op, wgShape, callback); + + SmallVector> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) + rewriter.create(op.getLoc(), v, op.getMemDesc(), + offsets, + layout.dropSgLayoutAndData()); + rewriter.eraseOp(op); + return success(); } }; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index 7dcdcca070ac9..e5cc65e6bd3d7 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -30,13 +30,9 @@ gpu.module @test_round_robin_assignment { //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[modY:%.+]] = index.remu [[ADDY]], [[C128]] + //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]] //CHECK: [[C64_2:%.+]] = arith.constant 64 : index - //CHECK: [[modX:%.+]] = index.remu [[ADDX]], [[C64_2]] - //CHECK: [[C0_3:%.+]] = arith.constant 0 : index - //CHECK: [[offX:%.+]] = index.add [[modX]], [[C0_3]] - //CHECK: [[C0_4:%.+]] = arith.constant 0 : index - //CHECK: [[offY:%.+]] = index.add [[modY]], [[C0_4]] + //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]] //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout> diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 07a0b86223c33..32157a7911f62 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -263,4 +263,62 @@ gpu.module @test_distribution { } {sg_id_range = #xegpu.range<[3, 19]>} gpu.return } + + // CHECK-LABEL: distribute_load_matrix + // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> + gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) { + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[c4:%.+]] = arith.constant 4 : index + //CHECK: [[c4_0:%.+]] = arith.constant 4 : index + //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] + //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]] + //CHECK: [[c32_1:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[c0_1:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index + //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index + //CHECK: [[c64:%.+]] = arith.constant 64 : index + 
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]] + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]] + //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> + %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> + gpu.return + } + + //CHECK-LABEL: distribute_store_matrix + //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> + gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) { + //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index + //CHECK: [[c2:%.+]] = arith.constant 2 : index + //CHECK: [[c4:%.+]] = arith.constant 4 : index + //CHECK: [[c4_0:%.+]] = arith.constant 4 : index + //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] + //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]] + //CHECK: [[c32_1:%.+]] = arith.constant 32 : index + //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[c0_2:%.+]] = arith.constant 0 : index + //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index + //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index + //CHECK: [[c64:%.+]] = arith.constant 64 : index + //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index + %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<64x128xf32> + %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> + xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index bdda77a69f22e..f4a49da71605f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -17,13 +17,9 @@ gpu.module @test_1_1_assignment { //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index //CHECK: [[C256:%.+]] = arith.constant 256 : index - //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]] + //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]] //CHECK: [[C128:%.+]] = arith.constant 128 : index - //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]] - //CHECK: [[C0_3:%.+]] = arith.constant 0 : index - //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_3]] - //CHECK: [[C0_4:%.+]] = arith.constant 0 : index - //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_4]] + //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]] //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> @@ -376,72 +372,4 @@ 
gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<256x128xf32> gpu.return } - - // CHECK-LABEL: distribute_load_matrix - // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> - gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) { - //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> - //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index - //CHECK: [[c2:%.+]] = arith.constant 2 : index - //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[c4_0:%.+]] = arith.constant 4 : index - //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] - //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] - //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]] - //CHECK: [[c32_1:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[c0_1:%.+]] = arith.constant 0 : index - //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index - //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index - //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y_0]], [[c64]] - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]] - //CHECK: [[c0_2:%.+]] = arith.constant 0 : index - //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_2]] - //CHECK: [[c0_3:%.+]] = arith.constant 0 : index - //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_3]] - //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> - %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> - %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> - gpu.return - } - - //CHECK-LABEL: distribute_store_matrix - //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3> - gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) { - //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32> - //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> - //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index - //CHECK: [[c2:%.+]] = arith.constant 2 : index - //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[c4_0:%.+]] = arith.constant 4 : index - //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]] - //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]] - //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]] - //CHECK: [[c32_1:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[c0_2:%.+]] = arith.constant 0 : index - //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index - //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index - //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y]], [[c64]] - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]] - //CHECK: [[c0_3:%.+]] = arith.constant 0 : index - //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_3]] - //CHECK: [[c0_4:%.+]] = arith.constant 0 : index - 
//CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_4]] - //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index - %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<64x128xf32> - %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> - xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> - - gpu.return - } - } From ce07282d88568f25f8c6fb29c7327f1e26624e1d Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 16:56:29 +0000 Subject: [PATCH 05/12] rename isWgLayout to isForWorkgroup --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 16 ++++++++-------- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 6 +++--- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ++-- .../Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 10 +++++----- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 11 ++++++----- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 2 +- .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 4 ++-- 7 files changed, 27 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index de86141ad006a..fe1f127bcd6b6 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -184,7 +184,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface" let methods = [ InterfaceMethod<"Check the availability of workgroup level layouts", "bool", - "isWgLayout">, + "isForWorkgroup">, InterfaceMethod<"Get the rank of attribute", "int64_t", "getRank">, @@ -337,12 +337,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf ]; let extraClassDeclaration = [{ - bool isWgLayout() { + bool isForWorkgroup() { return getSgLayout() != nullptr; } - bool isSgLayout() { - return !isWgLayout(); + bool isForSubgroup() { + return !isForWorkgroup(); } int64_t getRank() { @@ -454,16 +454,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface return parent.getOrder(); } - bool isWgLayout() const { + bool isForWorkgroup() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - return parent.isWgLayout(); + return parent.isForWorkgroup(); } - bool isSgLayout() const { + bool isForSubgroup() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - return parent.isSgLayout(); + return parent.isForSubgroup(); } int64_t getNumSubgroups() { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index de118b7faea4d..9e6702dda2de3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -271,7 +271,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId) { // delinearizeSubgroupId is only available for // workgroup-level layout attribute - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); // TODO: handle order attribute @@ -296,7 +296,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, FailureOr>> LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape) { - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); SmallVector sgLayout = getSgLayoutAsInt().value(); @@ -384,7 +384,7 @@ FailureOr>> SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef 
shape) { assert(getRank() == static_cast(shape.size()) && "invalid shape."); - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); SmallVector sgLayout = getSgLayoutAsInt().value(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 0e22af900daf1..ff538ebed4bad 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -938,8 +938,8 @@ LogicalResult ConvertLayoutOp::verify() { // both input and target layouts should be WgLayout or SgLayout at the same // time. - if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) && - (!srcLayout.isSgLayout() || !resLayout.isSgLayout())) + if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) && + (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup())) return emitOpError("expected input layout and target layout be WgLayout or " "SgLayout at the same time."); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index d82c541f31359..b3144e4c1e55d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -141,7 +141,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { value = (Value)operandOrResult; xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult); - if (layout && layout.isSgLayout()) { + if (layout && layout.isForSubgroup()) { if (auto inst_data = layout.getInstData()) return llvm::to_vector_of(inst_data.asArrayRef()); @@ -205,12 +205,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { bool hasWgLayoutOperands = llvm::any_of(op->getOpOperands(), [](OpOperand &opr) { xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr); - return layout && layout.isWgLayout(); + return layout && layout.isForWorkgroup(); }); bool hasWgLayoutResults = llvm::any_of(op->getOpResults(), [](OpResult result) { xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result); - return layout && layout.isWgLayout(); + return layout && layout.isForWorkgroup(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { LDBG() << "skip unrolling for op with workgroup level layout: " << *op; @@ -272,7 +272,7 @@ void XeGPUBlockingPass::runOnOperation() { auto layout = llvm::dyn_cast_if_present(type.getEncoding()); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; @@ -289,7 +289,7 @@ void XeGPUBlockingPass::runOnOperation() { ArrayRef shape = type.getShape(); xegpu::LayoutAttr layout = type.getLayoutAttr(); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 76bda64ffac1e..55957d9b264fc 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -59,7 +59,7 @@ getSgShapeAndCount(ArrayRef shape, xegpu::DistributeLayoutAttrInterface layout) { int count = 1; SmallVector sgShape(shape); - if (layout && layout.isWgLayout()) { + if (layout && layout.isForWorkgroup()) { SmallVector sgLayout = layout.getSgLayoutAsInt().value(); if (auto maybeSgData = layout.getSgDataAsInt()) sgShape = *maybeSgData; @@ -115,7 +115,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, // not applicable to ops without workgroup layout attributes xegpu::DistributeLayoutAttrInterface layout = 
op.getLayoutAttr(); - if (!layout || !layout.isWgLayout()) + if (!layout || !layout.isForWorkgroup()) return failure(); Value sgId = rewriter.create(loc, /*upper_bound=*/nullptr); @@ -249,7 +249,7 @@ struct WgToSgCreateNdOpNoOffset MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); auto layout = dyn_cast(tdescTy.getLayout()); - if (!layout || !layout.isWgLayout()) + if (!layout || !layout.isForWorkgroup()) return failure(); Type elemTy = tdescTy.getElementType(); @@ -637,7 +637,8 @@ struct WgToSgConvertLayoutOp xegpu::LayoutAttr input = op.getInputLayout(); xegpu::LayoutAttr target = op.getTargetLayout(); - if (!input || !target || !input.isWgLayout() || !target.isWgLayout()) + if (!input || !target || !input.isForWorkgroup() || + !target.isForWorkgroup()) return rewriter.notifyMatchFailure( op, "Input and target layouts must have subgroup layout"); @@ -938,7 +939,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { }; auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool { - return !layout || !layout.isWgLayout(); + return !layout || !layout.isForWorkgroup(); }; target.addDynamicallyLegalOp(tdescTy.getLayout()); // It only works for subgroup level layout, which only has lane_layout // and lane_data, and is to distribute a SIMD code into SIMT code. - if (!layout || !layout.isSgLayout()) + if (!layout || !layout.isForSubgroup()) return failure(); SmallVector laneData(layout.getLaneData().asArrayRef()); diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 8d2fb85655c72..86bb3af326da2 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -82,7 +82,7 @@ struct TestXeGPUUnrollingPatterns if (auto layout = tdescTy.getLayoutAttr()) { auto inst_data = layout.getInstData(); - if (inst_data && layout.isSgLayout()) + if (inst_data && layout.isForSubgroup()) return SmallVector(inst_data.asArrayRef().begin(), inst_data.asArrayRef().end()); } @@ -239,7 +239,7 @@ struct TestXeGPULayoutInterface ConversionTarget target(*ctx); auto isLegal = [&](xegpu::SliceAttr layout) -> bool { - return !layout || !layout.isWgLayout(); + return !layout || !layout.isForWorkgroup(); }; target.addDynamicallyLegalOp( From 9ae490cf05850458f7f81635ba4be21dc8d26ac1 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 17:37:58 +0000 Subject: [PATCH 06/12] cleanup getNumSubgroups --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index fe1f127bcd6b6..0fe4e22f50376 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -190,7 +190,12 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface" "getRank">, InterfaceMethod<"Get the num of effective subgroups", "int64_t", - "getNumSubgroups">, + "getNumSubgroups", (ins), [{ + std::optional> sgLayout = llvm::cast(tablegen_opaque_val).getSgLayoutAsInt(); + if (sgLayout.has_value()) + return computeProduct(*sgLayout); + return 0; + }], [{}]>, InterfaceMethod<"Get the SgLayout field of the attribute as integer array", "std::optional>", "getSgLayoutAsInt">, @@ -355,13 +360,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf return 0; } - int64_t getNumSubgroups() { - std::optional> 
sgLayout = getSgLayoutAsInt(); - if (sgLayout.has_value()) - return computeProduct(*sgLayout); - return 0; - } - LayoutAttr dropSgLayoutAndData() { // avoid every field of the attribute is nullptr, which may lead to segment fault if (!getInstData() && !getLaneLayout()) @@ -466,13 +464,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface return parent.isForSubgroup(); } - int64_t getNumSubgroups() { - std::optional> sgLayout = getSgLayoutAsInt(); - if (sgLayout.has_value()) - return computeProduct(*sgLayout); - return 0; - } - /// Returns the SgLayout of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. std::optional> getSgLayoutAsInt() const { From 36e3e3d632e36c9fcbc98791389020e088d7285f Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 17:45:39 +0000 Subject: [PATCH 07/12] update comments --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 55957d9b264fc..c5b497dcc695e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -76,9 +76,15 @@ getSgShapeAndCount(ArrayRef shape, return std::make_pair(sgShape, count); } -// An util helper to generate elementwise addition ops for index computing. -// lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match. -// left-alignment is performed. +/// Generates element-wise addition ops of two arrays with automatic alignment. +/// When the input arrays have different sizes, the shorter array is right-aligned +/// with the longer array, and the unmatched leading elements from the longer array +/// are preserved unchanged. This is commonly used for offset computation where +/// higher-dimensional offsets need to be added to lower-dimensional adjustments. 
+/// +/// Example: +/// lhs = [10, 20, 30], rhs = [5, 7] +/// Result: [10, 25, 37] (20+5, 30+7, with 10 preserved) static SmallVector genIndexAdds(ConversionPatternRewriter &rewriter, Location loc, ArrayRef lhs, ArrayRef rhs) { From 6d0458f2f5f2d5424a4cad7567309c6ababf787e Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 17:49:27 +0000 Subject: [PATCH 08/12] cleanup --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 2 +- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 10 ++++---- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 24 +++++++++---------- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 ++++--- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ++-- .../Transforms/XeGPUWgToSgDistribute.cpp | 12 +++++----- .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +- 7 files changed, 31 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h index 1d152f0c9ca9a..1481859e94a92 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h @@ -24,7 +24,7 @@ namespace mlir { namespace xegpu { class TensorDescType; -class DistributeLayoutAttrInterface; +class DistributeLayoutAttr; class LayoutAttr; class SliceAttr; } // namespace xegpu diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 0fe4e22f50376..b4d696444cc44 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -175,7 +175,7 @@ def XeGPU_FenceScopeAttr: let assemblyFormat = "$value"; } -def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> { +def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { let cppNamespace = "::mlir::xegpu"; let description = [{ Common trait for all XeGPU layouts. @@ -203,7 +203,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface" "std::optional>", "getSgDataAsInt">, InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData", - "xegpu::DistributeLayoutAttrInterface", + "xegpu::DistributeLayoutAttr", "dropSgLayoutAndData">, InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional indices based on the effective subgroup layout.}], @@ -220,7 +220,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface" ]; } -def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> { +def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { let summary = [{ Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor. 
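As an aside, the point of the shared interface being renamed here is that conversion patterns can query distribution information uniformly, whether the attribute is a LayoutAttr or a SliceAttr wrapping one. A minimal sketch of such a query, assuming the post-rename API from this series (the free function is illustrative, not part of the patch):

    #include "mlir/Dialect/XeGPU/IR/XeGPU.h"

    // Returns how many subgroups a workgroup-level layout distributes over,
    // or 0 for subgroup-level layouts, which carry no sg_layout to count.
    static int64_t countSubgroupsOrZero(mlir::Attribute attr) {
      auto layout =
          llvm::dyn_cast_if_present<mlir::xegpu::DistributeLayoutAttr>(attr);
      if (!layout || !layout.isForWorkgroup())
        return 0;
      return layout.getNumSubgroups(); // product of the effective sg_layout
    }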
@@ -407,7 +407,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf } -def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> { +def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { let summary = [{Describes the data distribution and sharing among subgroups or work-items.}]; let description = [{ @@ -434,7 +434,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface }]; let parameters = (ins - "xegpu::DistributeLayoutAttrInterface": $parent, + "xegpu::DistributeLayoutAttr": $parent, "DenseI64ArrayAttr": $dims ); diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 3182552288ca6..f3eaf400e1e4c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -232,8 +232,8 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface return static_cast(MemorySpace::Global); } - xegpu::DistributeLayoutAttrInterface getLayoutAttr() { - return dyn_cast_if_present(getType().getLayout()); + xegpu::DistributeLayoutAttr getLayoutAttr() { + return dyn_cast_if_present(getType().getLayout()); } ArrayRef getDistributeShape() { @@ -279,8 +279,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttrInterface getLayoutAttr() { - return dyn_cast_if_present(getTensorDescType().getLayout()); + xegpu::DistributeLayoutAttr getLayoutAttr() { + return dyn_cast_if_present(getTensorDescType().getLayout()); } ArrayRef getDistributeShape() { @@ -377,8 +377,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttrInterface getLayoutAttr() { - return dyn_cast_if_present(getTensorDescType().getLayout()); + xegpu::DistributeLayoutAttr getLayoutAttr() { + return dyn_cast_if_present(getTensorDescType().getLayout()); } ArrayRef getDistributeShape() { @@ -469,8 +469,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttrInterface getLayoutAttr() { - return dyn_cast_if_present(getTensorDescType().getLayout()); + xegpu::DistributeLayoutAttr getLayoutAttr() { + return dyn_cast_if_present(getTensorDescType().getLayout()); } ArrayRef getDistributeShape() { @@ -1211,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$layout + OptionalAttr:$layout ); let results = (outs XeGPU_ValueType:$res); let assemblyFormat = [{ @@ -1236,7 +1236,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, let builders = [ OpBuilder<(ins "Type":$res, "TypedValue": $mem_desc, - "llvm::ArrayRef": $offsets, "DistributeLayoutAttrInterface": $layout)>, + "llvm::ArrayRef": $offsets, "DistributeLayoutAttr": $layout)>, ]; let extraClassDeclaration = [{ SmallVector getMixedOffsets() { @@ -1259,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$layout + OptionalAttr:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands)}]; @@ -1278,7 +1278,7 @@ def XeGPU_StoreMatrixOp: 
XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, }]; let builders = [ OpBuilder<(ins "Value" : $data, "TypedValue": $mem_desc, - "llvm::ArrayRef": $offsets, "DistributeLayoutAttrInterface": $layout)>, + "llvm::ArrayRef": $offsets, "DistributeLayoutAttr": $layout)>, ]; let extraClassDeclaration = [{ SmallVector getMixedOffsets() { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9e6702dda2de3..a2d708be0e937 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -290,7 +290,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, return affine::delinearizeIndex(builder, loc, linearId, dims); } -/// Implements DistributeLayoutAttrInterface::getOffsets to generate +/// Implements DistributeLayoutAttr::getOffsets to generate /// instructions for computing multi-dimensional offsets when distributed by /// LayoutAttr. FailureOr>> @@ -323,8 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, //===----------------------------------------------------------------------===// LogicalResult SliceAttr::verify(llvm::function_ref emitError, - xegpu::DistributeLayoutAttrInterface parent, - DenseI64ArrayAttr dims) { + xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) { if (!parent || !dims) return emitError() << "expected parent layout and dims attribute"; @@ -342,7 +341,7 @@ SliceAttr::verify(llvm::function_ref emitError, } SliceAttr SliceAttr::flatten() const { - xegpu::DistributeLayoutAttrInterface parent = getParent(); + xegpu::DistributeLayoutAttr parent = getParent(); SmallVector slicedDims({getDims()}); while (auto sliceAttr = dyn_cast(parent)) { @@ -377,7 +376,7 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, return parent.delinearizeSubgroupId(builder, loc, linearId); } -/// Implements DistributeLayoutAttrInterface::getOffsets to generate +/// Implements DistributeLayoutAttr::getOffsets to generate /// instructions for computing multi-dimensional offsets when distributed by /// SliceAttr. 
FailureOr>> diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index ff538ebed4bad..c8d180b973f05 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, TypedValue memDesc, llvm::ArrayRef offsets, - DistributeLayoutAttrInterface layout) { + DistributeLayoutAttr layout) { llvm::SmallVector dynamicOffsets; llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() { void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, TypedValue memDesc, llvm::ArrayRef offsets, - DistributeLayoutAttrInterface layout) { + DistributeLayoutAttr layout) { llvm::SmallVector dynamicOffsets; llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c5b497dcc695e..09aa1e61c20e6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -56,7 +56,7 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, static std::pair, int> getSgShapeAndCount(ArrayRef shape, - xegpu::DistributeLayoutAttrInterface layout) { + xegpu::DistributeLayoutAttr layout) { int count = 1; SmallVector sgShape(shape); if (layout && layout.isForWorkgroup()) { @@ -120,7 +120,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, return failure(); // not applicable to ops without workgroup layout attributes - xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -217,7 +217,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern { xegpu::TensorDescType tdescTy = op.getType(); ArrayRef wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); - xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), @@ -806,7 +806,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { VectorType valueTy = op.getRes().getType(); Type elemTy = valueTy.getElementType(); - xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; VectorType newResTy = VectorType::get(sgShape, elemTy); SmallVector newOps; @@ -832,7 +832,7 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern { if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); - xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) rewriter.create(op.getLoc(), v, op.getMemDesc(), offsets, @@ -944,7 +944,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return xegpu::TensorDescType(); }; - auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool { + auto isLegal = 
[&](xegpu::DistributeLayoutAttr layout) -> bool { return !layout || !layout.isForWorkgroup(); }; diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 86bb3af326da2..200323c7a4e51 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -156,7 +156,7 @@ struct TestXeGPUUnrollingPatterns #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") // Test pattern for distributing vector::StepOp from workgroup to subgroup. -// Validates DistributeLayoutAttrInterface interfaces for offset computation +// Validates DistributeLayoutAttr interfaces for offset computation // abstraction between LayoutAttr and SliceAttr. class TestStepOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; From 69ff3ca5e524875d08d0a39000abddb6045287e5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 18:01:56 +0000 Subject: [PATCH 09/12] fix format --- .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 09aa1e61c20e6..2b4bfc0b15778 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -77,10 +77,11 @@ getSgShapeAndCount(ArrayRef shape, } /// Generates element-wise addition ops of two arrays with automatic alignment. -/// When the input arrays have different sizes, the shorter array is right-aligned -/// with the longer array, and the unmatched leading elements from the longer array -/// are preserved unchanged. This is commonly used for offset computation where -/// higher-dimensional offsets need to be added to lower-dimensional adjustments. +/// When the input arrays have different sizes, the shorter array is +/// right-aligned with the longer array, and the unmatched leading elements from +/// the longer array are preserved unchanged. This is commonly used for offset +/// computation where higher-dimensional offsets need to be added to +/// lower-dimensional adjustments. /// /// Example: /// lhs = [10, 20, 30], rhs = [5, 7] From 4f93bcbc6f8244ab00514b4d3a32d22f97b16428 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 20 Aug 2025 22:14:08 +0000 Subject: [PATCH 10/12] rename genIndexAdd --- .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 15 ++++++++++ .../Transforms/XeGPUWgToSgDistribute.cpp | 30 ++----------------- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 29 ++++++++++++++++++ 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index db8608c6d20b8..b2b2d3ab85231 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_ #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" namespace mlir { class VectorType; @@ -128,6 +129,20 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op, /// if no GPU module parent or XeVM target attribute exists. std::optional getChipStr(Operation *op); +/// Generates element-wise addition ops of two arrays with automatic alignment. 
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+///   lhs = [l1, l2, l3], rhs = [r1, r2]
+///   Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
+                                              ArrayRef<OpFoldResult> lhs,
+                                              ArrayRef<OpFoldResult> rhs);
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 2b4bfc0b15778..b22aa42a0fd01 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -76,32 +76,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
   return std::make_pair(sgShape, count);
 }
 
-/// Generates element-wise addition ops of two arrays with automatic alignment.
-/// When the input arrays have different sizes, the shorter array is
-/// right-aligned with the longer array, and the unmatched leading elements from
-/// the longer array are preserved unchanged. This is commonly used for offset
-/// computation where higher-dimensional offsets need to be added to
-/// lower-dimensional adjustments.
-///
-/// Example:
-///   lhs = [10, 20, 30], rhs = [5, 7]
-///   Result: [10, 25, 37] (20+5, 30+7, with 10 preserved)
-static SmallVector<OpFoldResult>
-genIndexAdds(ConversionPatternRewriter &rewriter, Location loc,
-             ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
-  // ensure a is longer than b
-  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
-  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
-  SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
-  a = a.slice(a.size() - b.size());
-  for (auto [l, r] : llvm::zip(a, b)) {
-    auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, l);
-    auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, r);
-    results.push_back(rewriter.createOrFold<index::AddOp>(loc, lval, rval));
-  }
-  return results;
-}
-
 /// Utility helper for deriving a list of offsets for each sub-TensorDescs
 /// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
 /// associated distribute layout attribute, the shape, subgroup id and the
@@ -150,8 +124,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // or sub-memory descriptor.
   // SmallVector<SmallVector<OpFoldResult>> offsetsList;
   for (const auto &sgOffsets : *maybeDescOffsets) {
-    SmallVector<OpFoldResult> newOffsets =
-        genIndexAdds(rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
     offsetsList.push_back(std::move(newOffsets));
   }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 535e2b10353c9..6835f64ad8ef7 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -424,3 +425,30 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {
 
   return std::nullopt;
 }
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+///   lhs = [l1, l2, l3], rhs = [r1, r2]
+///   Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+                           ArrayRef<OpFoldResult> lhs,
+                           ArrayRef<OpFoldResult> rhs) {
+  // ensure a is longer than b
+  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+  SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+  a = a.slice(a.size() - b.size());
+  for (auto [l, r] : llvm::zip(a, b)) {
+    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+  }
+  return results;
+}

From 2fb9ac7e6c775a4c70c1216874d83baa7118e84c Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Thu, 21 Aug 2025 14:32:04 +0000
Subject: [PATCH 11/12] refactor

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b22aa42a0fd01..b6b42ca40f1d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,24 +34,19 @@ using namespace mlir;
 
 namespace {
 
-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
-                                 int64_t &endOfRange) {
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
   Operation *parent = op->getParentOp();
-  // Find the outermost scf::IfOp with xegpu.sg_id_range.
while (parent) { if (auto ifOp = dyn_cast(parent)) { if (auto attr = llvm::dyn_cast_or_null( ifOp->getAttr("sg_id_range"))) { - startOfRange = attr.getStart().getInt(); - endOfRange = attr.getEnd().getInt(); - break; + return attr; } } parent = parent->getParentOp(); } - // Return false if startOfRange is 0 - return (startOfRange > 0 && endOfRange > startOfRange); + return {}; } static std::pair, int> @@ -101,16 +96,21 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, Value sgId = rewriter.create(loc, /*upper_bound=*/nullptr); - // adjust the linearId if the range specifier is present - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); - if (sgIdRangeSpecified) { + // verify and adjust the sgId if the range specifier is present + xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op); + if (sgIdRange) { + int64_t startOfRange = sgIdRange.getStart().getInt(); + int64_t endOfRange = sgIdRange.getEnd().getInt(); + // verify the RangeAttr against the layout attribute if (layout.getNumSubgroups() != endOfRange - startOfRange) return rewriter.notifyMatchFailure( op, "sg_layout size must match the sg_id_range"); - Value startOfRangeVal = - rewriter.create(loc, startOfRange); - sgId = rewriter.create(loc, sgId, startOfRangeVal); + // adjust the sgId if necessary + if (startOfRange > 0) { + Value startOfRangeVal = + rewriter.create(loc, startOfRange); + sgId = rewriter.create(loc, sgId, startOfRangeVal); + } } // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory From af6f83f71002c45c98da61c28635e986420485f4 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 21 Aug 2025 14:47:30 +0000 Subject: [PATCH 12/12] refactor --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 12 ++++++------ .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 18 +++++++----------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index f3eaf400e1e4c..ab471a1f33ef9 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -236,7 +236,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface return dyn_cast_if_present(getType().getLayout()); } - ArrayRef getDistributeShape() { + ArrayRef getDataShape() { return getTensorDescShape(); } @@ -283,7 +283,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { return dyn_cast_if_present(getTensorDescType().getLayout()); } - ArrayRef getDistributeShape() { + ArrayRef getDataShape() { return getTensorDescType().getShape(); } @@ -381,7 +381,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ return dyn_cast_if_present(getTensorDescType().getLayout()); } - ArrayRef getDistributeShape() { + ArrayRef getDataShape() { return getTensorDescType().getShape(); } @@ -473,7 +473,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ return dyn_cast_if_present(getTensorDescType().getLayout()); } - ArrayRef getDistributeShape() { + ArrayRef getDataShape() { return getTensorDescType().getShape(); } @@ -1243,7 +1243,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, return getMixedValues(getConstOffsets(), getOffsets(), getContext()); } - ArrayRef getDistributeShape() { + ArrayRef getDataShape() { return getRes().getType().getShape(); } }]; @@ -1285,7 +1285,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, return 
getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
 
-    ArrayRef<int64_t> getDistributeShape() {
+    ArrayRef<int64_t> getDataShape() {
       return getData().getType().getShape();
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index b6b42ca40f1d6..93b4efcd125ec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -36,15 +36,12 @@ namespace {
 
 // Retrieve the RangeAttr if it is specified.
 static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
-  Operation *parent = op->getParentOp();
+  Operation *parent = op->getParentOfType<scf::IfOp>();
   while (parent) {
-    if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
-      if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
-              ifOp->getAttr("sg_id_range"))) {
-        return attr;
-      }
-    }
-    parent = parent->getParentOp();
+    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+            parent->getAttr("sg_id_range")))
+      return attr;
+    parent = parent->getParentOfType<scf::IfOp>();
   }
   return {};
 }
@@ -115,14 +112,13 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
 
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
-  ArrayRef<int64_t> wgShape = op.getDistributeShape();
+  ArrayRef<int64_t> wgShape = op.getDataShape();
   auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
   if (failed(maybeDescOffsets))
     return failure();
 
   // Compute the final global offsets for each accessed sub-tensor
   // or sub-memory descriptor.
-  // SmallVector<SmallVector<OpFoldResult>> offsetsList;
   for (const auto &sgOffsets : *maybeDescOffsets) {
     SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
         rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
@@ -777,7 +773,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    ArrayRef<int64_t> wgShape = op.getDistributeShape();
+    ArrayRef<int64_t> wgShape = op.getDataShape();
     VectorType valueTy = op.getRes().getType();
     Type elemTy = valueTy.getElementType();
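Taken together, the range handling refactored in the last two patches enforces a single invariant: when an enclosing scf.if carries sg_id_range = [start, end), the layout's subgroup count must equal end - start, and the linear subgroup id is rebased by subtracting start. A standalone sketch of that invariant, with plain integers standing in for the IR values and illustrative names:

    #include <cassert>
    #include <cstdint>
    #include <optional>

    struct SgIdRange { int64_t start, end; };

    // Returns the rebased subgroup id, or std::nullopt when the range does
    // not match the layout's subgroup count (the pattern bails out then).
    static std::optional<int64_t>
    rebaseSgId(int64_t sgId, int64_t numSubgroups, std::optional<SgIdRange> range) {
      if (!range)
        return sgId; // no range specifier: use the hardware id directly
      if (range->end - range->start != numSubgroups)
        return std::nullopt; // sg_layout size must match the sg_id_range
      return sgId - range->start; // ids in range map to 0..numSubgroups-1
    }

    int main() {
      // sg_id_range = [3, 19) over a 16-subgroup layout: id 5 rebases to 2.
      assert(*rebaseSgId(5, 16, SgIdRange{3, 19}) == 2);
      // A mismatched range (size 8 vs. 16 subgroups) is rejected.
      assert(!rebaseSgId(5, 16, SgIdRange{3, 11}).has_value());
      return 0;
    }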