@@ -270,33 +270,31 @@ LogicalResult LoadNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  // Handling a 1D vector as the result can be complex. It may represent the
-  // outcome of a 1D block load in SIMD mode or a fragment of a block load
-  // result in SIMT mode. In the latter case, the tensor descriptor must be
-  // evenly distributed, with each lane holding an equally sized fragment of
-  // the result. Only subgroup size 8 or 16 is supported.
-  if (valueTy.getRank() == 1 &&
-      valueTy.getNumElements() < tdescTy.getNumElements()) {
+  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
+  int valueElems = valueTy.getNumElements();
+
+  // If the result vector is 1D and has less elements than the tensor
+  // descriptor, it is supposed to be a SIMT op. The layout attribute in
+  // tensor_desc is not needed.
+  if (valueElems < tdescElems && valueTy.getRank() == 1) {
     // SIMT mode doesn't need LayoutAttr.
     if (tdescTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
-    int valueElems = valueTy.getNumElements();
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    // For SIMT code, the load is evenly distributed across all lanes in a
+    // subgroup. Since subgroup size is arch dependent, we only check even
+    // distribution here.
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Result shape " << makeString(getShapeOf(valueTy))
              << " is not a valid distribution for tensor descriptor "
              << tdescTy;
-    }
+
     return success();
   }
 
   // Check SIMD mode.
-  auto array_len = tdescTy.getArrayLength();
   // adjusted tensor descriptor shape tracks the expected shape of the result.
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
@@ -328,6 +326,7 @@ LogicalResult LoadNdOp::verify() {
     }
   }
 
+  auto array_len = tdescTy.getArrayLength();
   if (array_len > 1) {
     tdescShape.insert(tdescShape.begin(), array_len);
   }
@@ -366,25 +365,23 @@ LogicalResult StoreNdOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto tdescShape = getShapeOf(dstTy);
-  auto valueShape = getShapeOf(valTy);
+  auto array_len = dstTy.getArrayLength();
+  if (array_len > 1)
+    return emitOpError("array length is not supported by store_nd.\n");
+
+  auto tdescElems = dstTy.getNumElements();
+  auto valueElems = valTy.getNumElements();
 
-  // Similar to LoadNdOp, handling a 1D vector as the value can be complex. It
-  // may represent the input of a 1D block store in SIMD mode or a fragment of
-  // a block store input in SIMT mode. In the latter case, the tensor descriptor
-  // must be evenly distributed, with each lane holding an equally sized
-  // fragment of the input. Only subgroup size 8 or 16 is supported.
-  if (valTy.getRank() == 1 && valTy.getNumElements() < dstTy.getNumElements()) {
+  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
+  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
+  // in tensor_desc is not needed.
+  if (valTy.getRank() == 1 && valueElems < tdescElems) {
     // SIMT mode doesn't need LayoutAttr.
     if (dstTy.getLayoutAttr())
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    int tdescElems = dstTy.getNumElements() * dstTy.getArrayLength();
-    int valueElems = valueShape[0];
-
-    int lanes = tdescElems % valueElems == 0 ? tdescElems / valueElems : -1;
-    if (lanes != 16 && lanes != 8) {
+    if (tdescElems % valueElems) {
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
@@ -393,6 +390,8 @@ LogicalResult StoreNdOp::verify() {
   }
 
   // SIMD code should have the same shape as the tensor descriptor.
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
   if (tdescShape != valueShape) {
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "