@@ -439,14 +439,20 @@ MemDescTransOp::inferReturnTypes(MLIRContext *context,
439439 return failure ();
440440 }
441441 }
442+
443+ // Permute the last `rank` dims of the source alloc shape.
444+ SmallVector<int64_t > allocShape =
445+ applyPermutation (argTy.getAllocShape ().take_back (order.size ()), order);
446+ allocShape.insert (allocShape.begin (), argTy.getAllocShape ().begin (),
447+ argTy.getAllocShape ().end () - order.size ());
448+
442449 inferredReturnTypes.push_back (
443450 MemDescType::get (retShape, retEltTy, retEncoding, argTy.getMemorySpace (),
444- argTy.getMutableMemory ()));
451+ argTy.getMutableMemory (), allocShape ));
445452 return success ();
446453}
447454
448455// MemDescReshapeOp
449-
450456LogicalResult MemDescReshapeOp::verify () {
451457 MemDescType dstType = getResult ().getType ();
452458 MemDescType srcType = getSrc ().getType ();
@@ -472,6 +478,13 @@ LogicalResult MemDescReshapeOp::verify() {
472478 return success ();
473479}
474480
481+ // MemDescReinterpretOp
482+ LogicalResult MemDescReinterpretOp::verify () {
483+ if (getSrc ().getType ().getMemorySpace () != getType ().getMemorySpace ())
484+ return emitError (" source and destination memory space must match" );
485+ return success ();
486+ }
487+
475488// LocalAllocOp
476489void LocalAllocOp::getEffects (
477490 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -623,20 +636,15 @@ LogicalResult MemDescSubviewOp::verify() {
623636 " only nD -> (n-1)D rank-reducing subviews are supported" );
624637 }
625638 for (auto offset : getOffsets ().take_back (dstTy.getRank ())) {
626- if (auto constOp = offset.getDefiningOp <arith::ConstantOp>()) {
627- if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue ())) {
628- if (offsetInt.getInt () != 0 ) {
629- return emitError (" only first offset can be non-zero for a "
630- " rank-reducing subview" );
631- }
632- } else {
633- return emitError (
634- " only integer constant values are allowed for the split" );
635- }
636- } else {
639+ APInt value;
640+ if (!matchPattern (offset, m_ConstantInt (&value))) {
637641 return emitError (" only constant values are allowed outside the front "
638642 " dimension in a rank-reducing subview" );
639643 }
644+ if (!value.isZero ()) {
645+ return emitError (
646+ " only first offset can be non-zero for a rank-reducing subview" );
647+ }
640648 }
641649 return success ();
642650 }
@@ -658,16 +666,10 @@ LogicalResult MemDescSubviewOp::verify() {
658666 }
659667 SmallVector<int64_t > offsets;
660668 for (auto offset : getOffsets ()) {
661- if (auto constOp = offset.getDefiningOp <arith::ConstantOp>()) {
662- if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue ())) {
663- offsets.push_back (offsetInt.getInt ());
664- } else {
665- return emitError (
666- " only integer constant values are allowed for the split" );
667- }
668- } else {
669+ APInt value;
670+ if (!matchPattern (offset, m_ConstantInt (&value)))
669671 return emitError (" only constant values are allowed for the split" );
670- }
672+ offsets. push_back (value. getSExtValue ());
671673 }
672674 // Identity subview
673675 if (dim == -1 ) {
0 commit comments