diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..9e80b8e453a73 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                        AnyTypeOf<[XeGPU_MaskType, I1]>:$mask,
                        OptionalAttr<I64Attr>:$chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+                       OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -852,6 +853,16 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       return getSource().getType();
     }
 
+    xegpu::DistributeLayoutAttr getDistributeLayout() {
+      xegpu::DistributeLayoutAttr layout = nullptr;
+      if (auto tdescType = getTensorDescType()) {
+        layout = tdescType.getLayoutAttr();
+      }
+      if (!layout)
+        layout = getLayoutAttr();
+      return layout;
+    }
+
     TypedValue<xegpu::TensorDescType> getTensorDesc() {
       if (auto tdescType = getTensorDescType()) {
         return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
@@ -895,7 +906,19 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                    "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
 
   let hasVerifier = 1;
@@ -979,7 +1002,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                        AnyTypeOf<[XeGPU_MaskType, I1]>:$mask,
                        OptionalAttr<I64Attr>:$chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+                       OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -993,6 +1017,16 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       return TypedValue<xegpu::TensorDescType>();
     }
 
+    xegpu::DistributeLayoutAttr getDistributeLayout() {
+      xegpu::DistributeLayoutAttr layout = nullptr;
+      if (auto tdescType = getTensorDescType()) {
+        layout = tdescType.getLayoutAttr();
+      }
+      if (!layout)
+        layout = getLayoutAttr();
+      return layout;
+    }
+
     xegpu::TensorDescType getTensorDescType() {
       return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
@@ -1030,7 +1064,19 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                    "IntegerAttr": $chunk_size,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                   "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                   "IntegerAttr": $chunk_size,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
 
   let hasVerifier = 1;
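Note on the TableGen change above: both ops gain an optional inherent `layout` attribute, and `getDistributeLayout()` resolves the effective layout with the tensor-descriptor type taking precedence over the op-level attribute. A minimal consumer-side sketch of that lookup order (the helper name `queryLayout` is hypothetical; the accessors are the ones generated by this patch):

    // Prefer the layout baked into the tensor_desc type; fall back to the
    // op's own inherent `layout` attribute (which may still be null).
    static xegpu::DistributeLayoutAttr queryLayout(xegpu::LoadGatherOp op) {
      if (auto tdescTy = op.getTensorDescType())
        if (auto fromType = tdescTy.getLayoutAttr())
          return fromType;
      return op.getLayoutAttr();
    }
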
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788d0b9b4..fa42ffbad4d31 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -816,7 +816,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
-        l1_hint, l2_hint, l3_hint);
+        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -832,7 +832,34 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
   auto offset = vector::FromElementsOp::create(builder, loc, type, values);
 
   build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source, Value mask,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+        l1_hint, l2_hint, l3_hint, layout);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
@@ -883,7 +910,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -901,7 +928,33 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
   // Call the correct builder overload that does not expect result types.
   build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
-        l3_hint);
+        l3_hint, /*layout=*/nullptr);
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest, Value mask,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint,
+                           DistributeLayoutAttr layout) {
+  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+        l2_hint, l3_hint, layout);
+}
+
+void StoreScatterOp::build(
+    OpBuilder &builder, OperationState &state, Value value, Value dest,
+    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
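The legacy builders now forward `/*layout=*/nullptr`, so existing call sites are unaffected; the two new overloads thread the layout through and, as before, materialize constant offsets via vector.from_elements. A hedged usage sketch of the new overload (builder, loc, src, mask, and layout are assumed in scope; shapes are illustrative):

    // Build a gather load that carries its distribute layout as an inherent
    // attribute rather than a discardable `layout_result_0` attribute.
    SmallVector<OpFoldResult> offsets(16, builder.getIndexAttr(0));
    auto resTy = VectorType::get({16, 8}, builder.getF16Type());
    auto load = xegpu::LoadGatherOp::create(
        builder, loc, resTy, src, offsets, mask,
        /*chunk_size=*/builder.getI64IntegerAttr(8), /*l1_hint=*/nullptr,
        /*l2_hint=*/nullptr, /*l3_hint=*/nullptr, layout);
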
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..fad698bd2001f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -687,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset
       auto newOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newValueTy, op.getSource(), o, m,
           rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), op.getLayoutAttr());
       newOps.push_back(newOp);
     }
 
@@ -783,7 +783,7 @@ struct UnrollStoreScatterOpWithOffsets
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                     rewriter.getI64IntegerAttr(chunkSize),
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), op.getLayoutAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c28d2fc6c2b63..33bf2808abd01 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -913,7 +913,8 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ nullptr);
       xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                      layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
@@ -961,19 +962,16 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      xegpu::DistributeLayoutAttr newLayout = nullptr;
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        // Update the layout attribute to drop sg_layout and sg_data.
+        newLayout = layout.dropSgLayoutAndData();
+
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      // Update the layout attribute to drop sg_layout and sg_data.
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty()) {
-        for (OpOperand &operand : store->getOpOperands()) {
-          // Skip for operand one (memref)
-          if (operand.getOperandNumber() == 1)
-            continue;
-          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
-        }
-      }
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ newLayout);
     }
     rewriter.eraseOp(op);
     return success();
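Behavioral note for the two transform updates: the unroll patterns simply copy the op's `layout` attribute onto each unrolled op, while the wg-to-sg patterns pass the layout through the builder instead of tagging individual operands, and a layout survives distribution only if dropping sg_layout/sg_data leaves inst_data or lane_layout behind. A sketch of that guard, mirroring the code above (wgLayout is an assumed in-scope DistributeLayoutAttr):

    // Attach a layout to the per-subgroup op only when something remains
    // after the subgroup dimensions are dropped; otherwise leave it unset.
    xegpu::DistributeLayoutAttr sgLayout = nullptr;
    if (!wgLayout.getEffectiveLaneLayoutAsInt().empty() ||
        !wgLayout.getEffectiveInstDataAsInt().empty())
      sgLayout = wgLayout.dropSgLayoutAndData();
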
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..c9ef37dc1daa9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -109,6 +109,8 @@ std::string xegpu::getLayoutName(const OpOperand &operand) {
 }
 
 std::string xegpu::getLayoutName(const OpResult result) {
+  if (isa<xegpu::LoadGatherOp>(result.getOwner()))
+    return "layout";
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
 }
@@ -141,6 +143,9 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
       return storeOp.getLayoutAttr();
 
+    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+      return loadGatherOp.getLayoutAttr();
+
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<DistributeLayoutAttr>(layoutName);
@@ -168,6 +173,12 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
     return storeOp.getLayoutAttr();
 
+  // if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op))
+  //   return loadGatherOp.getDistributeLayout();
+
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    return storeScatterOp.getDistributeLayout();
+
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<DistributeLayoutAttr>(layoutName);
@@ -196,7 +207,8 @@ template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
 void xegpu::setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
-    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
+    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::LoadGatherOp,
+            xegpu::StoreScatterOp>(nestOp))
       return;
 
     for (OpOperand &opr : nestOp->getOpOperands()) {
@@ -216,6 +228,9 @@ void xegpu::removeLayoutAttr(const T &operandOrResult) {
   std::string name = xegpu::getLayoutName(operandOrResult);
   if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
     owner->removeAttr(name);
+  if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(owner) &&
+      owner->hasAttrOfType<DistributeLayoutAttr>("layout"))
+    owner->removeAttr("layout");
 }
 
 // Explicit instantiation for OpResult
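With the utility changes, the helpers treat the inherent attribute as the canonical storage for these ops: getLayoutName() returns "layout" for a gather-load result, getDistributeLayoutAttr() reads the op's attribute directly, setDistributeLayoutAttrs() skips the ops entirely, and removeLayoutAttr() also clears the inherent "layout". An illustrative round trip (loadOp is an assumed in-scope xegpu::LoadGatherOp):

    // Query, then clear, the layout of a gather load through the utilities.
    std::string name = xegpu::getLayoutName(loadOp->getOpResult(0)); // "layout"
    xegpu::DistributeLayoutAttr l =
        xegpu::getDistributeLayoutAttr(loadOp.getResult());
    xegpu::removeLayoutAttr(loadOp->getOpResult(0)); // drops "layout" too
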
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 30f785ded975a..f807bcbda1984 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -97,7 +97,7 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -122,7 +122,7 @@ func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -167,8 +167,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64,
+// CHECK-SAME: layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
@@ -186,7 +186,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
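The test churn above is purely syntactic: the layout moves from the trailing discardable-attribute dict into the op's inherent-attribute segment. Before/after, using the chunked load from these tests (MLIR assembly; SSA names are illustrative):

    // Before: discardable attribute after the properties segment.
    %v = xegpu.load %src[%off], %m <{chunk_size = 8 : i64}>
      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>

    // After: inherent attribute inside <{...}>.
    %v = xegpu.load %src[%off], %m <{chunk_size = 8 : i64,
      layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
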
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 0e1365aa64171..f8e8794b4e066 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -136,9 +136,9 @@ gpu.module @xevm_module{
     %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %loaded = scf.if %pred -> (vector<16x8xf16>) {
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8,
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       scf.yield %3 : vector<16x8xf16>
     } else {
       %3 = arith.constant {
@@ -168,9 +168,9 @@ gpu.module @xevm_module{
     %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     scf.if %pred {
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8,
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     }
     gpu.return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 742d11f8052ec..888cbddebf2f0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -274,7 +274,7 @@ gpu.module @test_distribution {
   // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
   %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
   %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
-  %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+  %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
     : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
   gpu.return
 }
@@ -285,16 +285,13 @@ gpu.module @test_distribution {
   // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
   // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
   // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
-  // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
-  // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
-  // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
+  // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>,
+  // CHECK-SAME: layout = #xegpu.layout<inst_data = [8]>}>
   // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
   %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
   %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
   %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
-  xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
-                                           layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
-                                           layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+  xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
                                            l1_hint = #xegpu.cache_hint<cached>}
     : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
   gpu.return
@@ -309,7 +306,7 @@ gpu.module @test_distribution {
   // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
   %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
   %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
-  %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+  %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
    : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
   gpu.return
 }
@@ -401,7 +398,7 @@ gpu.module @test_distribution {
   %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} dense<0.0> : vector<4x2x6xf16>
   %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<0> : vector<4x2x6x32xindex>
   %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<4x2x6x32xi1>
-  %load = xegpu.load %src[%offset], %mask {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
+  %load = xegpu.load %src[%offset], %mask {layout = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
   // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16>
   %reduce = vector.multi_reduction <add>, %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} [3] : vector<4x2x6x32xf16> to vector<4x2x6xf16>
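The store side reads analogously after this patch; a sketch assembled from the wg-to-sg tests above (MLIR assembly; the layout values and shapes are illustrative, not normative):

    // Post-distribution scattered store carrying the residual layout as an
    // inherent attribute; sg_layout/sg_data have already been dropped.
    xegpu.store %val, %dest[%off], %m
      <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [8]>}>
      : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>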