@@ -760,6 +760,9 @@ LogicalResult MakeDmaDescriptorOp::verify() {
760760 if (rank < 2 ) {
761761 return emitOpError (" tensor and tile must be at least of rank 2." );
762762 }
763+ if (rank > 5 ) {
764+ return emitOpError (" tensor and tile must be at most of rank 5." );
765+ }
763766 if (rank != globalStaticStrides.size ()) {
764767 return emitOpError (" strides and sizes must have same rank." );
765768 }
@@ -779,6 +782,82 @@ LogicalResult MakeDmaDescriptorOp::verify() {
779782 return success ();
780783}
781784
785+ static bool maybeUpdateDynamicIndexList (
786+ ArrayRef<int64_t > staticElements, ArrayRef<Attribute> foldedElements,
787+ SmallVector<Value> dynamicElements, SmallVector<int64_t > &newStaticElements,
788+ SmallVector<Value> &newDynamicElements) {
789+ bool changed = false ;
790+ int index = 0 ;
791+
792+ for (int64_t static_element : staticElements) {
793+ if (!ShapedType::isDynamic (static_element)) {
794+ newStaticElements.push_back (static_element);
795+ continue ;
796+ }
797+
798+ Attribute folded_element = foldedElements[index++];
799+ if (auto attr = dyn_cast<IntegerAttr>(folded_element)) {
800+ newStaticElements.push_back (attr.getInt ());
801+ changed = true ;
802+ continue ;
803+ }
804+
805+ newStaticElements.push_back (ShapedType::kDynamic );
806+ newDynamicElements.push_back (dynamicElements[index]);
807+ }
808+ return changed;
809+ }
810+
811+ OpFoldResult MakeDmaDescriptorOp::fold (FoldAdaptor adaptor) {
812+ ArrayRef<int64_t > oldGlobalStaticStrides = adaptor.getGlobalStaticStrides ();
813+ ArrayRef<Attribute> foldedGlobalDynamicStrides =
814+ adaptor.getGlobalDynamicStrides ();
815+ SmallVector<Value> oldGlobalDynamicStrides = getGlobalDynamicStrides ();
816+
817+ SmallVector<int64_t > newGlobalStaticStrides;
818+ SmallVector<Value> newGlobalDynamicStrides;
819+
820+ bool change = maybeUpdateDynamicIndexList (
821+ oldGlobalStaticStrides, foldedGlobalDynamicStrides,
822+ oldGlobalDynamicStrides, newGlobalStaticStrides, newGlobalDynamicStrides);
823+
824+ ArrayRef<int64_t > oldGlobalStaticSizes = adaptor.getGlobalStaticSizes ();
825+ ArrayRef<Attribute> foldedGlobalDynamicSizes =
826+ adaptor.getGlobalDynamicSizes ();
827+ SmallVector<Value> oldGlobalDynamicSizes = getGlobalDynamicSizes ();
828+
829+ SmallVector<int64_t > newGlobalStaticSizes;
830+ SmallVector<Value> newGlobalDynamicSizes;
831+
832+ change |= maybeUpdateDynamicIndexList (
833+ oldGlobalStaticSizes, foldedGlobalDynamicSizes, oldGlobalDynamicSizes,
834+ newGlobalStaticSizes, newGlobalDynamicSizes);
835+
836+ ArrayRef<int64_t > oldSharedStaticSizes = adaptor.getSharedStaticSizes ();
837+ ArrayRef<Attribute> foldedSharedDynamicSizes =
838+ adaptor.getSharedDynamicSizes ();
839+ SmallVector<Value> oldSharedDynamicSizes = getSharedDynamicSizes ();
840+
841+ SmallVector<int64_t > newSharedStaticSizes;
842+ SmallVector<Value> newSharedDynamicSizes;
843+
844+ change |= maybeUpdateDynamicIndexList (
845+ oldSharedStaticSizes, foldedSharedDynamicSizes, oldSharedDynamicSizes,
846+ newSharedStaticSizes, newSharedDynamicSizes);
847+
848+ if (change) {
849+ setGlobalStaticStrides (newGlobalStaticStrides);
850+ getGlobalDynamicStridesMutable ().assign (newGlobalDynamicStrides);
851+ setGlobalStaticSizes (newGlobalStaticSizes);
852+ getGlobalDynamicSizesMutable ().assign (newGlobalDynamicSizes);
853+ setSharedStaticSizes (newSharedStaticSizes);
854+ getSharedDynamicSizesMutable ().assign (newSharedDynamicSizes);
855+ return getResult ();
856+ }
857+
858+ return nullptr ;
859+ }
860+
782861// ===----------------------------------------------------------------------===//
783862// ScaledMFMAOp
784863// ===----------------------------------------------------------------------===//
0 commit comments