@@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
112112 return Base::get (context, scopeAttr, lengthAttr, boundaryAttr);
113113}
114114
115+ bool BlockTensorDescAttr::hasDefaultsOnly () {
116+ return getMemorySpace ().getValue () == xegpu::MemorySpace::Global &&
117+ getArrayLength ().getInt () == 1 && getBoundaryCheck ().getValue ();
118+ }
119+
115120// ===----------------------------------------------------------------------===//
116121// XeGPU_ScatterTensorDescAttr
117122// ===----------------------------------------------------------------------===//
@@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
253258 if (parser.parseGreater ())
254259 return {};
255260
261+ MLIRContext *ctxt = parser.getContext ();
256262 return TensorDescType::getChecked (
257- [&]() { return parser.emitError (parser.getNameLoc ()); },
258- parser. getContext (), shape, elementType ,
259- encoding. value_or ( mlir::Attribute ()), layout.value_or (mlir::Attribute ()));
263+ [&]() { return parser.emitError (parser.getNameLoc ()); }, ctxt, shape,
264+ elementType, encoding. value_or ( BlockTensorDescAttr::get (ctxt)) ,
265+ layout.value_or (mlir::Attribute ()));
260266}
261267
262268void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
273279
274280 printer << getElementType ();
275281
276- if (auto encoding = getEncoding ())
282+ auto encoding = getEncoding ();
283+ auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
284+ if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly ()))
277285 printer << " , " << encoding;
278286
279287 if (auto layout = getLayout ())
0 commit comments