@@ -293,9 +293,8 @@ Type SparseTensorEncodingAttr::getCrdType() const {
293293SparseTensorEncodingAttr
294294SparseTensorEncodingAttr::withDimToLvl (AffineMap dimToLvl) const {
295295 assert (getImpl () && " Uninitialized SparseTensorEncodingAttr" );
296- // TODO: infer lvlToDim
297296 return SparseTensorEncodingAttr::get (getContext (), getLvlTypes (), dimToLvl,
298- /* lvlToDim */ AffineMap (), getPosWidth (),
297+ getLvlToDim (), getPosWidth (),
299298 getCrdWidth ());
300299}
301300
@@ -583,7 +582,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
583582#undef RETURN_ON_FAIL
584583
585584 // Construct struct-like storage for attribute.
586- AffineMap lvlToDim; // TODO: infer
585+ // TODO: Fetch lvlToDim if user provides one
586+ AffineMap lvlToDim = inferLvlToDim (dimToLvl, parser.getContext ());
587587 return parser.getChecked <SparseTensorEncodingAttr>(
588588 parser.getContext (), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
589589 dimSlices);
@@ -749,6 +749,75 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
749749 return nullptr ;
750750}
751751
752+ AffineMap mlir::sparse_tensor::inferLvlToDim (AffineMap dimToLvl,
753+ MLIRContext *context) {
754+ auto map = static_cast <AffineMap>(dimToLvl);
755+ AffineMap lvlToDim;
756+ // Return an empty lvlToDim when inference is not successful.
757+ if (!map || map.getNumSymbols () != 0 ) {
758+ lvlToDim = AffineMap ();
759+ } else if (map.isPermutation ()) {
760+ lvlToDim = inversePermutation (map);
761+ } else {
762+ // TODO: check if it's block sparsity
763+ lvlToDim = inverseBlockSparsity (map, context);
764+ }
765+ return lvlToDim;
766+ }
767+
768+ AffineMap mlir::sparse_tensor::inverseBlockSparsity (AffineMap dimToLvl,
769+ MLIRContext *context) {
770+ SmallVector<AffineExpr> lvlExprs;
771+ auto numLvls = dimToLvl.getNumResults ();
772+ lvlExprs.reserve (numLvls);
773+ // lvlExprComponents stores information of the floordiv and mod operations
774+ // applied to the same dimension, so as to build the lvlToDim map.
775+ std::map<unsigned , SmallVector<AffineExpr, 3 >> lvlExprComponents;
776+ for (unsigned i = 0 , n = numLvls; i < n; i++) {
777+ auto result = dimToLvl.getResult (i);
778+ if (auto binOp = result.dyn_cast <AffineBinaryOpExpr>()) {
779+ if (result.getKind () == AffineExprKind::FloorDiv) {
780+ // Position of the dimension in dimToLvl.
781+ auto pos = binOp.getLHS ().dyn_cast <AffineDimExpr>().getPosition ();
782+ assert (lvlExprComponents.find (pos) == lvlExprComponents.end () &&
783+ " expected only one floordiv for each dimension" );
784+ SmallVector<AffineExpr, 3 > components;
785+ // Level variable for floordiv.
786+ components.push_back (getAffineDimExpr (i, context));
787+ // Multiplier.
788+ components.push_back (binOp.getRHS ());
789+ // Map key is the position of the dimension.
790+ lvlExprComponents[pos] = components;
791+ } else if (result.getKind () == AffineExprKind::Mod) {
792+ auto pos = binOp.getLHS ().dyn_cast <AffineDimExpr>().getPosition ();
793+ assert (lvlExprComponents.find (pos) != lvlExprComponents.end () &&
794+ " expected floordiv before mod" );
795+ // Add level variable for mod to the same vector
796+ // of the corresponding floordiv.
797+ lvlExprComponents[pos].push_back (getAffineDimExpr (i, context));
798+ } else {
799+ assert (false && " expected floordiv or mod" );
800+ }
801+ } else {
802+ lvlExprs.push_back (getAffineDimExpr (i, context));
803+ }
804+ }
805+ // Build lvlExprs from lvlExprComponents.
806+ // For example, for il = i floordiv 2 and ii = i mod 2, the components
807+ // would be [il, 2, ii]. It could be used to build the AffineExpr
808+ // i = il * 2 + ii in lvlToDim.
809+ for (auto &components : lvlExprComponents) {
810+ assert (components.second .size () == 3 &&
811+ " expected 3 components to build lvlExprs" );
812+ auto mulOp = getAffineBinaryOpExpr (
813+ AffineExprKind::Mul, components.second [0 ], components.second [1 ]);
814+ auto addOp =
815+ getAffineBinaryOpExpr (AffineExprKind::Add, mulOp, components.second [2 ]);
816+ lvlExprs.push_back (addOp);
817+ }
818+ return dimToLvl.get (dimToLvl.getNumResults (), 0 , lvlExprs, context);
819+ }
820+
752821bool mlir::sparse_tensor::isCOOType (SparseTensorEncodingAttr enc,
753822 Level startLvl, bool isUnique) {
754823 if (!enc ||
@@ -811,7 +880,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
811880 // default value.
812881 unsigned posWidth = src.getPosWidth ();
813882 unsigned crdWidth = src.getCrdWidth ();
814- AffineMap invPerm; // TODO
883+ AffineMap invPerm = src. getLvlToDim ();
815884 auto enc = SparseTensorEncodingAttr::get (src.getContext (), lvlTypes, lvlPerm,
816885 invPerm, posWidth, crdWidth);
817886 return RankedTensorType::get (src.getDimShape (), src.getElementType (), enc);
0 commit comments