Skip to content

Commit 205fea7

Browse files
committed
address feedback
1 parent d3e935b commit 205fea7

File tree

3 files changed

+4
-12
lines changed

3 files changed

+4
-12
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
181181

182182
int64_t rank = getMixedSizes().size();
183183

184-
// Set constant offset to MAX to indicate no offsets provided
185-
// or else the printer can't differeiate this with valid const_offset value (say 0)
186-
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
184+
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
187185

188186
attr = getConstOffsetsAttr();
189187
return attr;

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,12 +334,6 @@ void printOptionalDynamicIndexList(
334334
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
335335
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
336336

337-
if (values.empty() && llvm::all_of(integers, [](int64_t i) {
338-
// MAX indiates no user-provided offsets for CreateNdDescOp.
339-
return i == std::numeric_limits<int64_t>::max();
340-
}))
341-
return;
342-
343337
return printDynamicIndexList(printer, op, values, integers,
344338
/*scalableFlags=*/{}, valueTypes, delimiter);
345339
}

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index,
6767
//CHECK: %[[C:.*]] = arith.constant 1 : index
6868
%c1 = arith.constant 1 : index
6969

70-
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
70+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
7171
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
7272

7373
gpu.return
@@ -77,7 +77,7 @@ gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index,
7777
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
7878

7979
%c1 = arith.constant 1 : index
80-
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
80+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
8181
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
8282

8383
gpu.return
@@ -97,7 +97,7 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
9797
// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
9898
gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
9999
%c1 = arith.constant 1 : index
100-
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
100+
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
101101
%2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
102102

103103
gpu.return

0 commit comments

Comments
 (0)