Skip to content

Commit 7717fa7

Browse files
committed
change print format
1 parent 55b417a commit 7717fa7

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
6464
)>
6565
];
6666

67+
let extraClassDeclaration = [{
68+
// return true if all fields of the BlockTensorDescAttr are set with
69+
// default values.
70+
bool hasDefaultsOnly();
71+
}];
72+
6773
}
6874

6975
def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
112112
return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
113113
}
114114

115+
bool BlockTensorDescAttr::hasDefaultsOnly() {
116+
return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
117+
getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
118+
}
119+
115120
//===----------------------------------------------------------------------===//
116121
// XeGPU_ScatterTensorDescAttr
117122
//===----------------------------------------------------------------------===//
@@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
253258
if (parser.parseGreater())
254259
return {};
255260

261+
MLIRContext *ctxt = parser.getContext();
256262
return TensorDescType::getChecked(
257-
[&]() { return parser.emitError(parser.getNameLoc()); },
258-
parser.getContext(), shape, elementType,
259-
encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
263+
[&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
264+
elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
265+
layout.value_or(mlir::Attribute()));
260266
}
261267

262268
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
273279

274280
printer << getElementType();
275281

276-
if (auto encoding = getEncoding())
282+
auto encoding = getEncoding();
283+
auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
284+
if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
277285
printer << ", " << encoding;
278286

279287
if (auto layout = getLayout())

0 commit comments

Comments
 (0)