Skip to content

Commit 6992b14

Browse files
committed
xegpu: setDescLayout retains TensorDesc BlockTensorDescAttrs
1 parent 8543b91 commit 6992b14

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,12 @@ xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
125125
xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
126126
xegpu::CreateNdDescOp descOp,
127127
xegpu::LayoutAttr layout) {
128-
auto oldTensorDesc = descOp.getResult();
129-
auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
128+
auto oldTensorDesc = descOp.getType();
130129
auto descType = xegpu::TensorDescType::get(
131-
descShapedType.getShape(), descShapedType.getElementType(),
132-
/*array_length=*/1,
133-
/*boundary_check=*/true,
134-
/*memory_space=*/xegpu::MemorySpace::Global,
130+
oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
131+
/*array_length=*/oldTensorDesc.getArrayLength(),
132+
/*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
133+
/*memory_space=*/oldTensorDesc.getMemorySpace(),
135134
/*layout=*/layout);
136135

137136
rewriter.setInsertionPointAfter(descOp);

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
// CHECK-LABEL: @set_desc_layout
44
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
55
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
6+
// CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
67
// CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
7-
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
8+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
89
return
910
}
1011

0 commit comments

Comments
 (0)