@@ -547,38 +547,27 @@ LogicalResult LoadGatherOp::verify() {
547547 return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
548548
549549 auto chunkSize = tdescTy.getChunkSize ();
550- // for SIMT code, the value should be 1D vector with size of chunkSize.
551- if (valueTy.getRank () == 1 && valueTy.getNumElements () != tdescShape[0 ]) {
552- if (valueTy.getNumElements () != chunkSize) {
550+
551+ // a valid shape for SIMT case
552+ if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
553+ if (tdescTy.getLayoutAttr ())
553554 return emitOpError ()
554- << " Result shape " << makeString (valueShape)
555- << " is not a valid distribution for tensor descriptor "
556- << tdescTy;
557- } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
558- if (tdescTy.getLayoutAttr ())
559- return emitOpError ()
560- << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
561- if (getTransposeAttr ())
562- return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
563- }
564- return success ();
565- } else if (valueTy.getRank () == 1 && tdescShape[0 ] == chunkSize) {
566- // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
567- // it is a valid SIMT code if chunkSize happens to be the same as
568- // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
555+ << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
556+ if (getTransposeAttr ())
557+ return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
569558 return success ();
570559 }
571560
572- // For SIMD code verification.
573- if (tdescTy.getRank () == 2 ) {
561+ if (tdescTy.getRank () == 2 && valueTy.getRank () == 2 ) {
574562 if (!getTransposeAttr ())
575563 return emitOpError (" load of rank-2 tensor has to be transposed." );
576564 transpose ({1 , 0 }, tdescShape);
577565 }
578566
579567 if (tdescShape != valueShape)
580568 return emitOpError () << " Result shape " << makeString (valueShape)
581- << " is not consistent with tensor descriptor "
569+ << " is neither a valid distribution for SIMT nor "
570+ " consistent with the tensor descriptor for SIMD "
582571 << tdescTy;
583572 return success ();
584573}
@@ -613,38 +602,27 @@ LogicalResult StoreScatterOp::verify() {
613602 return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
614603
615604 auto chunkSize = tdescTy.getChunkSize ();
616- // for SIMT code, the value should be 1D vector with size of chunkSize.
617- if (valueTy.getRank () == 1 && valueTy.getNumElements () != tdescShape[0 ]) {
618- if (valueTy.getNumElements () != chunkSize) {
605+
606+ // a valid shape for SIMT case
607+ if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
608+ if (tdescTy.getLayoutAttr ())
619609 return emitOpError ()
620- << " Value shape " << makeString (valueShape)
621- << " is not a valid distribution for tensor descriptor "
622- << tdescTy;
623- } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
624- if (tdescTy.getLayoutAttr ())
625- return emitOpError ()
626- << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
627- if (getTransposeAttr ())
628- return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
629- }
630- return success ();
631- } else if (valueTy.getRank () == 1 && tdescShape[0 ] == chunkSize) {
632- // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
633- // it is a valid SIMT code if chunkSize happens to be the same as
634- // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
610+ << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
611+ if (getTransposeAttr ())
612+ return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
635613 return success ();
636614 }
637615
638- // for SIMD code verification.
639- if (tdescTy.getRank () == 2 ) {
616+ if (tdescTy.getRank () == 2 && valueTy.getRank () == 2 ) {
640617 if (!getTransposeAttr ())
641618 return emitOpError (" Store of a rank-2 tensor has to be transposed." );
642619 transpose ({1 , 0 }, tdescShape);
643620 }
644621
645622 if (tdescShape != valueShape)
646623 return emitOpError () << " Value shape " << makeString (valueShape)
647- << " is not consistent with tensor descriptor "
624+ << " is neither a valid distribution for SIMT nor "
625+ " consistent with the tensor descriptor for SIMD "
648626 << tdescTy;
649627
650628 return success ();
0 commit comments