Skip to content

Commit 0e00bc4

Browse files
authored
[mlir][xegpu] cleanup the print format for TensorDesc (#149182)
1 parent 22b0835 commit 0e00bc4

File tree

8 files changed

+47
-47
lines changed

8 files changed

+47
-47
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
6464
)>
6565
];
6666

67+
let extraClassDeclaration = [{
68+
// return true if all fields of the BlockTensorDescAttr are set with
69+
// default values.
70+
bool hasDefaultsOnly();
71+
}];
72+
6773
}
6874

6975
def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -131,62 +131,48 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
131131
return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
132132
}
133133

134-
BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
135-
return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
136-
}
137-
138-
ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
139-
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
134+
template <typename T,
135+
typename = std::enable_if_t<
136+
std::is_same_v<T, BlockTensorDescAttr> ||
137+
std::is_same_v<T, ScatterTensorDescAttr>>>
138+
T getEncodingOfType() const {
139+
return llvm::dyn_cast_if_present<T>(getEncoding());
140140
}
141141

142142
LayoutAttr getLayoutAttr() const {
143143
return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
144144
}
145145

146146
xegpu::MemorySpace getMemorySpace() const {
147-
auto block_attr = getEncodingAsBlockTensorDescAttr();
148-
if (block_attr && block_attr.getMemorySpace())
149-
return block_attr.getMemorySpace().getValue();
150-
151-
auto scatter_attr = getEncodingAsScatterTensorDescAttr();
152-
if (scatter_attr && scatter_attr.getMemorySpace())
153-
return scatter_attr.getMemorySpace().getValue();
147+
if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
148+
return attr.getMemorySpace().getValue();
154149

155-
// return default value
156-
return MemorySpace::Global;
150+
auto attr = getEncodingOfType<ScatterTensorDescAttr>();
151+
return attr.getMemorySpace().getValue();
157152
}
158153

159154
// get the ArrayLength for blocked TensorDesc
160155
int getArrayLength() {
161-
auto attr = getEncoding();
162-
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
163-
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
164-
if (block_attr && block_attr.getArrayLength())
165-
return block_attr.getArrayLength().getInt();
166-
// return default value
167-
return 1;
156+
auto attr = getEncodingOfType<BlockTensorDescAttr>();
157+
assert(attr && "invalid on non BlockTensorDescAttr.");
158+
return attr.getArrayLength().getInt();
168159
}
169160

170161
bool getBoundaryCheck() {
171-
auto attr = getEncoding();
172-
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
173-
assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
174-
if (block_attr && block_attr.getBoundaryCheck())
175-
return block_attr.getBoundaryCheck().getValue();
176-
// return default value
177-
return true;
162+
auto attr = getEncodingOfType<BlockTensorDescAttr>();
163+
assert(attr && "invalid on non BlockTensorDescAttr.");
164+
return attr.getBoundaryCheck().getValue();
178165
}
179166

180167
bool isScattered() {
181-
return bool(getEncodingAsScatterTensorDescAttr());
168+
return bool(getEncodingOfType<ScatterTensorDescAttr>());
182169
}
183170

184171
// get the ChunkSize for scattered TensorDesc
185172
int getChunkSizeAsInt() {
186-
auto attr = getEncoding();
187-
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
188-
assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
189-
return scatter_attr.getChunkSizeAsInt();
173+
auto attr = getEncodingOfType<ScatterTensorDescAttr>();
174+
assert(attr && "invalid on non ScatterTensorDescAttr.");
175+
return attr.getChunkSizeAsInt();
190176
}
191177

192178
/// Helper to drop all layout information from the TensorDesc type.

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
112112
return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
113113
}
114114

115+
bool BlockTensorDescAttr::hasDefaultsOnly() {
116+
return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
117+
getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
118+
}
119+
115120
//===----------------------------------------------------------------------===//
116121
// XeGPU_ScatterTensorDescAttr
117122
//===----------------------------------------------------------------------===//
@@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
253258
if (parser.parseGreater())
254259
return {};
255260

261+
MLIRContext *ctxt = parser.getContext();
256262
return TensorDescType::getChecked(
257-
[&]() { return parser.emitError(parser.getNameLoc()); },
258-
parser.getContext(), shape, elementType,
259-
encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
263+
[&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
264+
elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
265+
layout.value_or(mlir::Attribute()));
260266
}
261267

262268
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
273279

274280
printer << getElementType();
275281

276-
if (auto encoding = getEncoding())
282+
auto encoding = getEncoding();
283+
auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
284+
if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
277285
printer << ", " << encoding;
278286

279287
if (auto layout = getLayout())

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
5454
std::multiplies<int64_t>());
5555

5656
// Case 1: regular loads/stores
57-
auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
57+
auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
5858
if (scatterAttr) {
5959
auto chunkSize = scatterAttr.getChunkSize().getInt();
6060
// Verify if the first dimension of the tensor descriptor shape is

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
3030
// CHECK-SAME: %[[OFFSET:.+]]: index
3131
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
3232
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
33-
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
33+
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
3434
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
3535
// CHECK: return %[[VEC]]
3636

@@ -55,7 +55,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
5555
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
5656
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
5757
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
58-
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
58+
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
5959
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
6060
// CHECK: return %[[VEC]]
6161

@@ -73,7 +73,7 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
7373
// CHECK-SAME: %[[OFFSET:.+]]: index
7474
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
7575
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
76-
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
76+
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
7777
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
7878
// CHECK: return %[[VEC]]
7979

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
3232
// CHECK-SAME: %[[OFFSET:.+]]: index
3333
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
3434
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
35-
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
35+
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
3636
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
3737

3838
// -----
@@ -57,7 +57,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
5757
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
5858
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
5959
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
60-
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
60+
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
6161
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
6262

6363
// -----
@@ -75,7 +75,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
7575
// CHECK-SAME: %[[OFFSET:.+]]: index
7676
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
7777
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
78-
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
78+
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
7979
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
8080

8181
// -----

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
5151
// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
5252
// CHECK-SAME: %[[OFFSET:.+]]: index
5353
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
54-
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
54+
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
5555
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
5656
// CHECK: return %[[VEC]]
5757

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
8080
// CHECK-SAME: %[[OFFSET:.+]]: index
8181
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
8282
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
83-
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
83+
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
8484
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
8585

8686
// -----

0 commit comments

Comments
 (0)