diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 1a6a34c8d775a..480b43e740736 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,11 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let builders = [
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-               "llvm::ArrayRef<OpFoldResult>": $shape,
-               "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value ": $source,
                "llvm::ArrayRef<OpFoldResult>": $shape,
                "llvm::ArrayRef<OpFoldResult>": $strides)>,
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2cd086feb5deb..4dd937eb5114d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -156,41 +156,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
+  Type srcTy = source.getType();
+  assert((isa<IntegerType, MemRefType>(srcTy)) &&
+         "Source has to be either int or memref.");
 
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
-        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
-        staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
-
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +175,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // If shape and strides come from the memref, we don't need attributes for
+    // them to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
         dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
         staticStridesAttr);
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
 void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                    OperandRange values, DenseI64ArrayAttr integers) {
-
-  if (!integers)
+  if (!integers || integers.empty())
     return;
-
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, {},
-                               AsmParser::Delimiter::Square);
+  printDynamicIndexList(printer, op, values, integers,
+                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
 }
 
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 4a5525c8abb30..97c97ac3fd680 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -128,6 +128,12 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Ensure that the op has explicit offsets specified (either dynamic or
+    // constant).
+    if (op.getMixedOffsets().empty())
+      return failure();
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
@@ -199,6 +205,49 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// This pattern transforms a CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor.
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Check that no offsets are specified.
+    if (!op.getMixedOffsets().empty())
+      return failure();
+
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout || !layout.isWgLayout())
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+
+    SmallVector<Value> newCreateNdOps(count);
+    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+                                           op.getSource(), op.getMixedSizes(),
+                                           op.getMixedStrides());
+    });
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -603,11 +652,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add(
-      patterns.getContext());
+  patterns
+      .add(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
new file mode 100644
index 0000000000000..b6f44b5bc0b68
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
new file mode 100644
index 0000000000000..025d48e22307e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_with_ptr
+  // CHECK-SAME: %[[ARG_0:.*]]: ui64
+  gpu.func @create_nd_tdesc_with_ptr(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]], shape : [{{.*}}, {{.*}}], strides : [{{.*}}, {{.*}}] : ui64
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout>
+    %c1 = arith.constant 1 : index
+    %tdesc = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides: [%w, %c1] : ui64
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+}