@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
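
The switch from get() to getChecked() wires the new TensorDescType::verify hook (second hunk, below) into the parser: a tensor_desc that breaks an invariant is now reported as a located parse error instead of asserting inside Base::get. A minimal sketch of the effect, assuming the dialect's !xegpu.tensor_desc mnemonic and an invalid 3D shape (the diagnostic text comes from the verifier below; the location formatting is illustrative):

    // input.mlir
    %d = ... : !xegpu.tensor_desc<2x4x8xf32>
    // input.mlir:N: error: expected 1D or 2D tensor
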
@@ -223,6 +224,81 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank != 1 && rank != 2)
+    return emitError() << "expected 1D or 2D tensor";
+
+  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+  if (scatterAttr) {
+    // Expected tensor ranks for scattered data:
+    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
+    //   - 2D tensor for scattered blocks (chunk size > 1)
+    IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+    unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+    if (rank == 2 && chunkSize < 2)
+      return emitError() << "expected chunk blocks for 2D tensor";
+  }
+
+  if (auto blockAttr =
+          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError()
+               << "outer layout distribution and data mapping must be 1 "
+                  "for 1D tensor";
+    }
+
+    if (scatterAttr) {
+      // Validate subgroup mapping rules for scattered tensors.
+      // A work-item's slice of a tensor with shape [sg_size] or
+      // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively,
+      // and the mapping should reflect that.
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered row elements";
+
+      IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+      unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+      if (wiData[1] != chunkSize)
+        return emitError() << "work item data mapping must match the number of "
+                              "contiguous elements";
+    }
+
+    // For a 1D tensor, pad the shape with an outer unit dimension to allow
+    // common validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot distribute " << tensorShape[i] << " over "
+                           << wiLayout[i] << " work items with " << wiData[i]
+                           << " elements each";
+    }
+  }
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
 
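For reference, a few types the new verifier accepts and rejects. This is an illustrative sketch, assuming the attribute mnemonics documented for the dialect (#xegpu.scatter_tdesc_attr, #xegpu.block_tdesc_attr, #xegpu.sg_map); the quoted diagnostics are the ones emitted above, not output captured from this patch's tests.

    // Accepted: 16 work items each own a [1, 1] slice of the 8x16 block.
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>

    // Accepted: 2D scattered blocks; wi_data[1] == chunk_size == 2, and both
    // dims distribute evenly (16 over 16x1, 2 over 1x2).
    !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
                       #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 2]>>

    // Rejected: chunk_size > 1 requires a 2D tensor.
    !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
    // error: expected non-contiguous elements for 1D tensor

    // Rejected: 2D block descriptor in shared local memory.
    !xegpu.tensor_desc<24x32xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
    // error: SLM is not supported for 2D block tensor

    // Rejected: 16 columns do not divide across 12 work items.
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 12], wi_data = [1, 1]>>
    // error: cannot distribute 16 over 12 work items with 1 elements each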