diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index abc291c81a76c..eb54d6887681d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 
   let builders = [
     OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
     OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index eee0fdc7160de..906c71d8b8dad 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -385,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
         l1_hint, l2_hint, l3_hint);
 }
 
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+        l2_hint, l3_hint);
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -427,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                      l3_hint);
 }
 
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                     UnitAttr packed, DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -533,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                      xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ecec186fe3fc9..8f1208e77ca5d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                                  layout.dropSgLayoutAndData());
 
     SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
 
     for (auto tdescOffsets : *maybeTdescOffsets) {
       SmallVector<Value> sgOffsets;
       size_t rank = tdescOffsets.size();
       for (size_t i = 0; i < rank; i++) {
-        size_t idx = wgOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+        size_t idx = origOffsets.size() - rank + i;
+        Value add = rewriter.createOrFold<index::AddOp>(
+            loc, tdescOffsets[i],
+            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
         sgOffsets.push_back(add);
       }
 
@@ -296,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// Utility function to compute global offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<Value>>
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+               ArrayRef<OpFoldResult> origOffsets,
+               ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<Value>> finalOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<Value> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = origOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    finalOffsets.push_back(std::move(newOffsets));
+  }
+  return finalOffsets;
+}
+
+// Utility function to get sgShape, sgOffsetList for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVector<int64_t> &sgShape,
+                           SmallVector<SmallVector<Value>> &sgOffsetList) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  auto sgLayoutAttr = layout.getSgLayout();
+  if (!sgLayoutAttr)
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  // Get the subgroup ID
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+  if (sgIdRangeSpecified) {
+    int64_t sgCount = endOfRange - startOfRange;
+    if (computeProduct(sgLayout) != sgCount)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
+    return failure();
+
+  sgOffsetList = *sgOffsets;
+  return success();
+}
+
+template <typename OpTy>
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+                                     ConversionPatternRewriter &rewriter) {
+  SmallVector<OpFoldResult> origOffsets;
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (auto attr : constOffsets.asArrayRef())
+      origOffsets.push_back(rewriter.getIndexAttr(attr));
+  }
+  for (auto v : op.getOffsets())
+    origOffsets.push_back(v);
+  return origOffsets;
+}
+
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, offsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc, value] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
 /// offsets of the new subgroup src tensor descriptors.
@@ -690,12 +889,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
-               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp,
+           WgToSgStoreNdOpWithOffset, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset,
+           UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+           WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+           WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b6f44b5bc0b68..6ff7a94d678a3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -10,5 +10,76 @@ gpu.module @test_distribution {
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
     gpu.return
-  }
+  }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout>
+    // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.load_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offset
+  gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout>
+    // CHECK-NOT: xegpu.store_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    xegpu.prefetch_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+  gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout}
+    // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
+      -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout>
+    %load_a = xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<256x128xf16, #xegpu.layout>
+      -> vector<256x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
+      -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout>
+    %load_b = xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x256xf16, #xegpu.layout>
+      -> vector<128x256xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 = #xegpu.layout}
+      : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 025d48e22307e..07a0b86223c33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -21,4 +23,244 @@ gpu.module @test_distribution {
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout> -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offsets
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    xegpu.prefetch_nd %tdesc[%cst0, %cst0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %load_a = xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %load_b = xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 = #xegpu.layout}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_no_sg_data
+  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %load_a = xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %load_b = xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 = #xegpu.layout}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_with_no_create_nd_desc
+  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+    // CHECK-NOT: vector<32x32xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout = #xegpu.layout}
+      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim1
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+  gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout>
+      -> vector<256x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout}
+      : vector<256x1xf32> to vector<256x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout>
+      -> vector<1x128xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout}
+      : vector<1x128xf32> to vector<32x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: gemm_with_load_store_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1024x1024xf16>, %[[ARG_1:.*]]: memref<1024x1024xf16>, %[[ARG_2:.*]]: memref<1024x1024xf32>
+  gpu.func @gemm_with_load_store_offset(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    // CHECK: [[DESC_A:%.+]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[DESC_B:%.+]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x16xf16>
+    %3 = xegpu.create_nd_tdesc %arg0 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    %4 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout>
+    // load_nd with offset
+    %5 = xegpu.load_nd %2[%0, %1] : !xegpu.tensor_desc<128x128xf32, #xegpu.layout> -> vector<128x128xf32>
+    %6 = xegpu.load_nd %3[%0, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+    %7 = xegpu.load_nd %4[%c0, %1] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+    // scf.for loop
+    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>)
+    // CHECK: [[c:%.+]] = xegpu.dpas [[arg4]], [[arg5]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: [[a:%.+]] = xegpu.load_nd [[DESC_A]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: [[b:%.+]] = xegpu.load_nd [[DESC_B]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: scf.yield [[a]], [[b]], [[c]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>
+    %8:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %6, %arg5 = %7, %arg6 = %5)
+        -> (vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>) {
+      // load_nd with offset inside loop
+      %9 = xegpu.dpas %arg4, %arg5, %arg6 {layout_result_0 = #xegpu.layout}
+        : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %10 = xegpu.load_nd %3[%arg3, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+      %11 = xegpu.load_nd %4[%c0, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout> -> vector<128x128xf16>
+      scf.yield %10, %11, %9 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>
+    }
+    // store_nd with offset
+    xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range
+  gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+      // CHECK-NOT: index.sub
+      %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      %load = xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+        -> vector<256x128xf32>
+    } {sg_id_range = #xegpu.range<[0, 32]>}
+    %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c31 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+    scf.if %cond5 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2 : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout>
+      %load = xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout} : vector<128x64xf32>
+    } {sg_id_range = #xegpu.range<[2, 18]>}
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1 : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout>
+        %ld = xegpu.load_nd %td[0, 0]
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout} : vector<128x64xf32>
+      }
+    } {sg_id_range = #xegpu.range<[3, 19]>}
+    gpu.return
+  }
 }
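As an illustrative sketch (not part of the patch), the new offset-taking LoadNdOp builder added in XeGPUOps.cpp above could be driven from C++ roughly as follows. The names rewriter, loc, tdesc, dynOff, l1Hint, l2Hint, and l3Hint are assumptions about the caller's context, not identifiers from this change; the point is that static and dynamic offsets are passed together as OpFoldResults, which the builder splits via dispatchIndexOpFoldResults into the const_offsets attribute and the dynamic offset operands.

    // Hypothetical caller-side snippet; all names below are assumed to exist
    // in the surrounding pass and are not taken from the patch.
    SmallVector<OpFoldResult> offsets;
    offsets.push_back(rewriter.getIndexAttr(0)); // static offset, folded into the attribute
    offsets.push_back(dynOff);                   // dynamic offset, kept as an SSA value

    VectorType resTy = VectorType::get({32, 32}, rewriter.getF32Type());
    auto load = rewriter.create<xegpu::LoadNdOp>(
        loc, resTy, tdesc, offsets,
        /*packed=*/nullptr, /*transpose=*/nullptr, l1Hint, l2Hint, l3Hint);

The same shape of call applies to the new StoreNdOp and PrefetchNdOp builders, minus the result type, packed, and transpose arguments.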