@@ -541,9 +541,9 @@ LogicalResult LoadGatherOp::verify() {
541541 if (tdescShape[0 ] != maskShape[0 ])
542542 return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
543543
544+ auto chunkSize = tdescTy.getChunkSize ();
544545 // for SIMT code, the value should be 1D vector with size of chunkSize.
545546 if (valueTy.getRank () == 1 && valueTy.getNumElements () != tdescShape[0 ]) {
546- auto chunkSize = tdescTy.getChunkSize ();
547547 if (valueTy.getNumElements () != chunkSize) {
548548 return emitOpError ()
549549 << " Result shape " << makeString (valueShape)
@@ -557,6 +557,11 @@ LogicalResult LoadGatherOp::verify() {
557557 return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
558558 }
559559 return success ();
560+ } else if (valueTy.getRank () == 1 && tdescShape[0 ] == chunkSize) {
561+ // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
562+ // it is a valid SIMT code if chunkSize happens to be the same as
563+ // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
564+ return success ();
560565 }
561566
562567 // For SIMD code verification.
@@ -602,9 +607,9 @@ LogicalResult StoreScatterOp::verify() {
602607 if (tdescShape[0 ] != maskShape[0 ])
603608 return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
604609
610+ auto chunkSize = tdescTy.getChunkSize ();
605611 // for SIMT code, the value should be 1D vector with size of chunkSize.
606612 if (valueTy.getRank () == 1 && valueTy.getNumElements () != tdescShape[0 ]) {
607- auto chunkSize = tdescTy.getChunkSize ();
608613 if (valueTy.getNumElements () != chunkSize) {
609614 return emitOpError ()
610615 << " Value shape " << makeString (valueShape)
@@ -618,6 +623,11 @@ LogicalResult StoreScatterOp::verify() {
618623 return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
619624 }
620625 return success ();
626+ } else if (valueTy.getRank () == 1 && tdescShape[0 ] == chunkSize) {
627+ // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
628+ // it is a valid SIMT code if chunkSize happens to be the same as
629+ // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
630+ return success ();
621631 }
622632
623633 // for SIMD code verification.
0 commit comments