@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
175175 if (parser.parseGreater ())
176176 return {};
177177
178- return TensorDescType::get (parser.getContext (), shape, elementType,
179- encoding.value_or (mlir::Attribute ()),
180- sg_map.value_or (mlir::Attribute ()));
178+ return TensorDescType::getChecked (
179+ [&]() { return parser.emitError (parser.getNameLoc ()); },
180+ parser.getContext (), shape, elementType,
181+ encoding.value_or (mlir::Attribute ()), sg_map.value_or (mlir::Attribute ()));
181182}
182183
183184void TensorDescType::print (::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
223224 return Base::get (context, shape, elementType, attr, sg_map);
224225}
225226
227+ LogicalResult TensorDescType::verify (
228+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
229+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
230+ mlir::Attribute encoding, mlir::Attribute sg_map) {
231+ size_t rank = shape.size ();
232+ if (rank > 2 )
233+ return emitError () << " desc shape rank exceeds 2" ;
234+
235+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
236+ ArrayRef<uint32_t > wiLayout = sgMapAttr.getWiLayout ();
237+ ArrayRef<uint32_t > wiData = sgMapAttr.getWiData ();
238+
239+ if (rank == 1 ) {
240+ if (wiLayout[0 ] != 1 || wiData[0 ] != 1 )
241+ return emitError () << " outer layout and data mapping must be 1 "
242+ " for 1D tensor" ;
243+ }
244+
245+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
246+ // validation logic.
247+ SmallVector<int64_t > tensorShape (shape.begin (), shape.end ());
248+ if (rank == 1 )
249+ tensorShape = {1 , tensorShape.back ()};
250+
251+ size_t dims = tensorShape.size ();
252+ for (size_t i = 0 ; i < dims; ++i) {
253+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
254+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0 )
255+ return emitError () << " cannot map " << tensorShape[i]
256+ << " elements into " << wiLayout[i] << " by "
257+ << wiData[i] << " tiles" ;
258+ }
259+
260+ if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
261+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
262+ if (wiData[0 ] != 1 )
263+ return emitError ()
264+ << " cannot map over non-contiguous scattered elements" ;
265+
266+ unsigned chunkSize = scatterAttr.getChunkSize ().getInt ();
267+ if (wiData[1 ] > chunkSize)
268+ return emitError ()
269+ << " too few contiguous elements for work item mapping" ;
270+ }
271+ }
272+
273+ return success ();
274+ }
275+
226276} // namespace xegpu
227277} // namespace mlir
228278
0 commit comments