@@ -55,77 +55,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555 return Base::get (context, scopeAttr, chunkSizeAttr);
5656}
5757
58- // ===----------------------------------------------------------------------===//
59- // XeGPU_SGMapAttr
60- // ===----------------------------------------------------------------------===//
61- namespace {
62- template <typename T, unsigned N>
63- LogicalResult parseIntArrayField (::mlir::AsmParser &parser,
64- llvm::SmallVector<T, N> &result,
65- llvm::StringRef fieldName) {
66- if (failed (parser.parseKeyword (fieldName))) {
67- parser.emitError (parser.getCurrentLocation (),
68- " unexpected field name. Expected " + fieldName + " ." );
69- return failure ();
70- }
71-
72- if (failed (parser.parseEqual ())) {
73- parser.emitError (parser.getCurrentLocation (), " expected '=' sign." );
74- return failure ();
75- }
76-
77- auto elemParser = [&]() -> llvm::ParseResult {
78- uint32_t elem = 0 ;
79- auto res = parser.parseInteger (elem);
80- result.push_back (elem);
81- return res;
82- };
83-
84- return parser.parseCommaSeparatedList (AsmParser::Delimiter::Square,
85- elemParser, fieldName);
86- }
87- } // namespace
88-
89- mlir::Attribute SGMapAttr::parse (::mlir::AsmParser &parser,
90- ::mlir::Type attrType) {
91- if (failed (parser.parseLess ()))
92- return {};
93-
94- llvm::SmallVector<uint32_t , 2 > wi_layout, wi_data;
95- if (failed (parseIntArrayField (parser, wi_layout, " wi_layout" )))
96- return {};
97-
98- if (failed (parser.parseComma ()))
99- return {};
100-
101- if (failed (parseIntArrayField (parser, wi_data, " wi_data" )))
102- return {};
103-
104- return SGMapAttr::getChecked (
105- [&]() { return parser.emitError (parser.getNameLoc ()); },
106- parser.getContext (), wi_layout, wi_data);
107- }
108-
109- void SGMapAttr::print (::mlir::AsmPrinter &printer) const {
110- printer << " <" ;
111- printer.printKeywordOrString (" wi_layout" );
112- printer << " = [" << getWiLayout () << " ], " ;
113- printer.printKeywordOrString (" wi_data" );
114- printer << " = [" << getWiData () << " ]" ;
115- printer << " >" ;
116- }
117-
118- LogicalResult
119- SGMapAttr::verify (llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120- llvm::ArrayRef<uint32_t > wi_layout,
121- llvm::ArrayRef<uint32_t > wi_data) {
122- if (wi_layout.size () != 2 )
123- return emitError () << " expected wi_layout of size 2" ;
124- if (wi_data.size () != 2 )
125- return emitError () << " expected wi_data of size 2" ;
126- return success ();
127- }
128-
12958// ===----------------------------------------------------------------------===//
13059// XeGPU_TensorDescType
13160// ===----------------------------------------------------------------------===//
@@ -134,7 +63,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
13463 llvm::SmallVector<int64_t > shape;
13564 mlir::Type elementType;
13665 mlir::FailureOr<mlir::Attribute> encoding;
137- mlir::FailureOr<mlir::Attribute> sg_map;
13866
13967 // Parse literal '<'
14068 if (parser.parseLess ())
@@ -153,31 +81,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
15381 }
15482
15583 // parse optional attributes
156- while (mlir::succeeded (parser.parseOptionalComma ())) {
157- mlir::Attribute attr;
158- ParseResult res = parser.parseAttribute (attr);
159- if (mlir::succeeded (res)) {
160- if (mlir::isa<SGMapAttr>(attr)) {
161- sg_map = attr;
162- continue ;
163- }
164- if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165- encoding = attr;
166- continue ;
167- }
84+ if (mlir::succeeded (parser.parseOptionalComma ())) {
85+ encoding = mlir::FieldParser<mlir::Attribute>::parse (parser);
86+ if (mlir::failed (encoding)) {
87+ parser.emitError (
88+ parser.getCurrentLocation (),
89+ " Failed to parse the attribute field for TensorDescType.\n " );
90+ return {};
16891 }
169- parser.emitError (parser.getCurrentLocation (),
170- " Failed to parse the attribute.\n " );
171- return {};
17292 }
17393
17494 // Parse literal '>'
17595 if (parser.parseGreater ())
17696 return {};
17797
17898 return TensorDescType::get (parser.getContext (), shape, elementType,
179- encoding.value_or (mlir::Attribute ()),
180- sg_map.value_or (mlir::Attribute ()));
99+ encoding.value_or (mlir::Attribute ()));
181100}
182101
183102void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -197,30 +116,25 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
197116 if (auto encoding = getEncoding ())
198117 printer << " , " << encoding;
199118
200- if (auto sg_map = getSgMap ())
201- printer << " , " << sg_map;
202-
203119 printer << " >" ;
204120}
205121
206122TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
207123 mlir::Type elementType, int array_length,
208124 bool boundary_check,
209- MemorySpace memory_space,
210- mlir::Attribute sg_map) {
125+ MemorySpace memory_space) {
211126 auto context = elementType.getContext ();
212127 auto attr = BlockTensorDescAttr::get (context, memory_space, array_length,
213128 boundary_check);
214- return Base::get (context, shape, elementType, attr, sg_map );
129+ return Base::get (context, shape, elementType, attr);
215130}
216131
217132TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
218133 mlir::Type elementType, int chunk_size,
219- MemorySpace memory_space,
220- mlir::Attribute sg_map) {
134+ MemorySpace memory_space) {
221135 auto context = elementType.getContext ();
222136 auto attr = ScatterTensorDescAttr::get (context, memory_space, chunk_size);
223- return Base::get (context, shape, elementType, attr, sg_map );
137+ return Base::get (context, shape, elementType, attr);
224138}
225139
226140} // namespace xegpu
0 commit comments