@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
5555 return Base::get (context, scopeAttr, chunkSizeAttr);
5656}
5757
58+ // ===----------------------------------------------------------------------===//
59+ // XeGPU_SGMapAttr
60+ // ===----------------------------------------------------------------------===//
61+ namespace {
62+ template <typename T, unsigned N>
63+ LogicalResult parseIntArrayField (::mlir::AsmParser &parser,
64+ llvm::SmallVector<T, N> &result,
65+ llvm::StringRef fieldName) {
66+ if (failed (parser.parseKeyword (fieldName))) {
67+ parser.emitError (parser.getCurrentLocation (),
68+ " unexpected field name. Expected " + fieldName + " ." );
69+ return failure ();
70+ }
71+
72+ if (failed (parser.parseEqual ())) {
73+ parser.emitError (parser.getCurrentLocation (), " expected '=' sign." );
74+ return failure ();
75+ }
76+
77+ auto elemParser = [&]() -> llvm::ParseResult {
78+ uint32_t elem = 0 ;
79+ auto res = parser.parseInteger (elem);
80+ result.push_back (elem);
81+ return res;
82+ };
83+
84+ return parser.parseCommaSeparatedList (AsmParser::Delimiter::Square,
85+ elemParser, fieldName);
86+ }
87+ } // namespace
88+
89+ mlir::Attribute SGMapAttr::parse (::mlir::AsmParser &parser,
90+ ::mlir::Type attrType) {
91+ if (failed (parser.parseLess ()))
92+ return {};
93+
94+ llvm::SmallVector<uint32_t , 2 > wi_layout, wi_data;
95+ if (failed (parseIntArrayField (parser, wi_layout, " wi_layout" )))
96+ return {};
97+
98+ if (failed (parser.parseComma ()))
99+ return {};
100+
101+ if (failed (parseIntArrayField (parser, wi_data, " wi_data" )))
102+ return {};
103+
104+ return SGMapAttr::getChecked (
105+ [&]() { return parser.emitError (parser.getNameLoc ()); },
106+ parser.getContext (), wi_layout, wi_data);
107+ }
108+
109+ void SGMapAttr::print (::mlir::AsmPrinter &printer) const {
110+ printer << " <" ;
111+ printer.printKeywordOrString (" wi_layout" );
112+ printer << " = [" << getWiLayout () << " ], " ;
113+ printer.printKeywordOrString (" wi_data" );
114+ printer << " = [" << getWiData () << " ]" ;
115+ printer << " >" ;
116+ }
117+
118+ LogicalResult
119+ SGMapAttr::verify (llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120+ llvm::ArrayRef<uint32_t > wi_layout,
121+ llvm::ArrayRef<uint32_t > wi_data) {
122+ if (wi_layout.size () != 2 )
123+ return emitError () << " expected wi_layout of size 2" ;
124+ if (wi_data.size () != 2 )
125+ return emitError () << " expected wi_data of size 2" ;
126+ return success ();
127+ }
128+
58129// ===----------------------------------------------------------------------===//
59130// XeGPU_TensorDescType
60131// ===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
63134 llvm::SmallVector<int64_t > shape;
64135 mlir::Type elementType;
65136 mlir::FailureOr<mlir::Attribute> encoding;
137+ mlir::FailureOr<mlir::Attribute> sg_map;
66138
67139 // Parse literal '<'
68140 if (parser.parseLess ())
@@ -81,22 +153,31 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
81153 }
82154
83155 // parse optional attributes
84- if (mlir::succeeded (parser.parseOptionalComma ())) {
85- encoding = mlir::FieldParser<mlir::Attribute>::parse (parser);
86- if (mlir::failed (encoding)) {
87- parser.emitError (
88- parser.getCurrentLocation (),
89- " Failed to parse the attribute field for TensorDescType.\n " );
90- return {};
156+ while (mlir::succeeded (parser.parseOptionalComma ())) {
157+ mlir::Attribute attr;
158+ ParseResult res = parser.parseAttribute (attr);
159+ if (mlir::succeeded (res)) {
160+ if (mlir::isa<SGMapAttr>(attr)) {
161+ sg_map = attr;
162+ continue ;
163+ }
164+ if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165+ encoding = attr;
166+ continue ;
167+ }
91168 }
169+ parser.emitError (parser.getCurrentLocation (),
170+ " Failed to parse the attribute.\n " );
171+ return {};
92172 }
93173
94174 // Parse literal '>'
95175 if (parser.parseGreater ())
96176 return {};
97177
98178 return TensorDescType::get (parser.getContext (), shape, elementType,
99- encoding.value_or (mlir::Attribute ()));
179+ encoding.value_or (mlir::Attribute ()),
180+ sg_map.value_or (mlir::Attribute ()));
100181}
101182
102183void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
116197 if (auto encoding = getEncoding ())
117198 printer << " , " << encoding;
118199
200+ if (auto sg_map = getSgMap ())
201+ printer << " , " << sg_map;
202+
119203 printer << " >" ;
120204}
121205
122206TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
123207 mlir::Type elementType, int array_length,
124208 bool boundary_check,
125- MemorySpace memory_space) {
209+ MemorySpace memory_space,
210+ mlir::Attribute sg_map) {
126211 auto context = elementType.getContext ();
127212 auto attr = BlockTensorDescAttr::get (context, memory_space, array_length,
128213 boundary_check);
129- return Base::get (context, shape, elementType, attr);
214+ return Base::get (context, shape, elementType, attr, sg_map );
130215}
131216
132217TensorDescType TensorDescType::get (llvm::ArrayRef<int64_t > shape,
133218 mlir::Type elementType, int chunk_size,
134- MemorySpace memory_space) {
219+ MemorySpace memory_space,
220+ mlir::Attribute sg_map) {
135221 auto context = elementType.getContext ();
136222 auto attr = ScatterTensorDescAttr::get (context, memory_space, chunk_size);
137- return Base::get (context, shape, elementType, attr);
223+ return Base::get (context, shape, elementType, attr, sg_map );
138224}
139225
140226} // namespace xegpu
0 commit comments