@@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
715715 return MemRefType::Builder (*this ).setLayout ({});
716716}
717717
718- // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719- // i.e. single term). Accumulate the AffineExpr into the existing one.
720- static void extractStridesFromTerm (AffineExpr e,
721- AffineExpr multiplicativeFactor,
722- MutableArrayRef<AffineExpr> strides,
723- AffineExpr &offset) {
724- if (auto dim = dyn_cast<AffineDimExpr>(e))
725- strides[dim.getPosition ()] =
726- strides[dim.getPosition ()] + multiplicativeFactor;
727- else
728- offset = offset + e * multiplicativeFactor;
729- }
730-
731- // / Takes a single AffineExpr `e` and populates the `strides` array with the
732- // / strides expressions for each dim position.
733- // / The convention is that the strides for dimensions d0, .. dn appear in
734- // / order to make indexing intuitive into the result.
735- static LogicalResult extractStrides (AffineExpr e,
736- AffineExpr multiplicativeFactor,
737- MutableArrayRef<AffineExpr> strides,
738- AffineExpr &offset) {
739- auto bin = dyn_cast<AffineBinaryOpExpr>(e);
740- if (!bin) {
741- extractStridesFromTerm (e, multiplicativeFactor, strides, offset);
742- return success ();
743- }
744-
745- if (bin.getKind () == AffineExprKind::CeilDiv ||
746- bin.getKind () == AffineExprKind::FloorDiv ||
747- bin.getKind () == AffineExprKind::Mod)
748- return failure ();
749-
750- if (bin.getKind () == AffineExprKind::Mul) {
751- auto dim = dyn_cast<AffineDimExpr>(bin.getLHS ());
752- if (dim) {
753- strides[dim.getPosition ()] =
754- strides[dim.getPosition ()] + bin.getRHS () * multiplicativeFactor;
755- return success ();
756- }
757- // LHS and RHS may both contain complex expressions of dims. Try one path
758- // and if it fails try the other. This is guaranteed to succeed because
759- // only one path may have a `dim`, otherwise this is not an AffineExpr in
760- // the first place.
761- if (bin.getLHS ().isSymbolicOrConstant ())
762- return extractStrides (bin.getRHS (), multiplicativeFactor * bin.getLHS (),
763- strides, offset);
764- return extractStrides (bin.getLHS (), multiplicativeFactor * bin.getRHS (),
765- strides, offset);
766- }
767-
768- if (bin.getKind () == AffineExprKind::Add) {
769- auto res1 =
770- extractStrides (bin.getLHS (), multiplicativeFactor, strides, offset);
771- auto res2 =
772- extractStrides (bin.getRHS (), multiplicativeFactor, strides, offset);
773- return success (succeeded (res1) && succeeded (res2));
774- }
775-
776- llvm_unreachable (" unexpected binary operation" );
777- }
778-
779- // / A stride specification is a list of integer values that are either static
780- // / or dynamic (encoded with ShapedType::kDynamic). Strides encode
781- // / the distance in the number of elements between successive entries along a
782- // / particular dimension.
783- // /
784- // / For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
785- // / non-contiguous memory region of `42` by `16` `f32` elements in which the
786- // / distance between two consecutive elements along the outer dimension is `1`
787- // / and the distance between two consecutive elements along the inner dimension
788- // / is `64`.
789- // /
790- // / The convention is that the strides for dimensions d0, .. dn appear in
791- // / order to make indexing intuitive into the result.
792- static LogicalResult getStridesAndOffset (MemRefType t,
793- SmallVectorImpl<AffineExpr> &strides,
794- AffineExpr &offset) {
795- AffineMap m = t.getLayout ().getAffineMap ();
796-
797- if (m.getNumResults () != 1 && !m.isIdentity ())
798- return failure ();
799-
800- auto zero = getAffineConstantExpr (0 , t.getContext ());
801- auto one = getAffineConstantExpr (1 , t.getContext ());
802- offset = zero;
803- strides.assign (t.getRank (), zero);
804-
805- // Canonical case for empty map.
806- if (m.isIdentity ()) {
807- // 0-D corner case, offset is already 0.
808- if (t.getRank () == 0 )
809- return success ();
810- auto stridedExpr =
811- makeCanonicalStridedLayoutExpr (t.getShape (), t.getContext ());
812- if (succeeded (extractStrides (stridedExpr, one, strides, offset)))
813- return success ();
814- assert (false && " unexpected failure: extract strides in canonical layout" );
815- }
816-
817- // Non-canonical case requires more work.
818- auto stridedExpr =
819- simplifyAffineExpr (m.getResult (0 ), m.getNumDims (), m.getNumSymbols ());
820- if (failed (extractStrides (stridedExpr, one, strides, offset))) {
821- offset = AffineExpr ();
822- strides.clear ();
823- return failure ();
824- }
825-
826- // Simplify results to allow folding to constants and simple checks.
827- unsigned numDims = m.getNumDims ();
828- unsigned numSymbols = m.getNumSymbols ();
829- offset = simplifyAffineExpr (offset, numDims, numSymbols);
830- for (auto &stride : strides)
831- stride = simplifyAffineExpr (stride, numDims, numSymbols);
832-
833- return success ();
834- }
835-
836718LogicalResult MemRefType::getStridesAndOffset (SmallVectorImpl<int64_t > &strides,
837719 int64_t &offset) {
838- // Happy path: the type uses the strided layout directly.
839- if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout ())) {
840- llvm::append_range (strides, strided.getStrides ());
841- offset = strided.getOffset ();
842- return success ();
843- }
844-
845- // Otherwise, defer to the affine fallback as layouts are supposed to be
846- // convertible to affine maps.
847- AffineExpr offsetExpr;
848- SmallVector<AffineExpr, 4 > strideExprs;
849- if (failed (::getStridesAndOffset (*this , strideExprs, offsetExpr)))
850- return failure ();
851- if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
852- offset = cst.getValue ();
853- else
854- offset = ShapedType::kDynamic ;
855- for (auto e : strideExprs) {
856- if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
857- strides.push_back (c.getValue ());
858- else
859- strides.push_back (ShapedType::kDynamic );
860- }
861- return success ();
720+ return getLayout ().getStridesAndOffset (getShape (), strides, offset);
862721}
863722
864723std::pair<SmallVector<int64_t >, int64_t > MemRefType::getStridesAndOffset () {
0 commit comments