@@ -104,7 +104,7 @@ void StorageLayout::foreachField(
104104 callback) const {
105105 const auto lvlTypes = enc.getLvlTypes ();
106106 const Level lvlRank = enc.getLvlRank ();
107- SmallVector<COOSegment> cooSegs = SparseTensorType ( enc) .getCOOSegments ();
107+ SmallVector<COOSegment> cooSegs = enc.getCOOSegments ();
108108 FieldIndex fieldIdx = kDataFieldStartingIdx ;
109109
110110 ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +211,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
211211 unsigned stride = 1 ;
212212 if (kind == SparseTensorFieldKind::CrdMemRef) {
213213 assert (lvl.has_value ());
214- const Level cooStart = SparseTensorType ( enc) .getAoSCOOStart ();
214+ const Level cooStart = enc.getAoSCOOStart ();
215215 const Level lvlRank = enc.getLvlRank ();
216216 if (lvl.value () >= cooStart && lvl.value () < lvlRank) {
217217 lvl = cooStart;
@@ -912,46 +912,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
912912 return emitError ()
913913 << " dimension-rank mismatch between encoding and tensor shape: "
914914 << getDimRank () << " != " << dimRank;
915+ if (auto expVal = getExplicitVal ()) {
916+ Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType ();
917+ if (attrType != elementType) {
918+ return emitError () << " explicit value type mismatch between encoding and "
919+ << " tensor element type: " << attrType
920+ << " != " << elementType;
921+ }
922+ }
923+ if (auto impVal = getImplicitVal ()) {
924+ Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType ();
925+ if (attrType != elementType) {
926+ return emitError () << " implicit value type mismatch between encoding and "
927+ << " tensor element type: " << attrType
928+ << " != " << elementType;
929+ }
930+ // Currently, we only support zero as the implicit value.
931+ auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
932+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
933+ auto impComplexVal = llvm::dyn_cast<complex ::NumberAttr>(impVal);
934+ if ((impFVal && impFVal.getValue ().isNonZero ()) ||
935+ (impIntVal && !impIntVal.getValue ().isZero ()) ||
936+ (impComplexVal && (impComplexVal.getImag ().isNonZero () ||
937+ impComplexVal.getReal ().isNonZero ()))) {
938+ return emitError () << " implicit value must be zero" ;
939+ }
940+ }
915941 return success ();
916942}
917943
918- // ===----------------------------------------------------------------------===//
919- // SparseTensorType Methods.
920- // ===----------------------------------------------------------------------===//
921-
922- bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
923- bool isUnique) const {
924- if (!hasEncoding ())
925- return false ;
926- if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
927- return false ;
928- for (Level l = startLvl + 1 ; l < lvlRank; ++l)
929- if (!isSingletonLvl (l))
930- return false ;
931- // If isUnique is true, then make sure that the last level is unique,
932- // that is, when lvlRank == 1, the only compressed level is unique,
933- // and when lvlRank > 1, the last singleton is unique.
934- return !isUnique || isUniqueLvl (lvlRank - 1 );
935- }
936-
937- Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart () const {
944+ Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart () const {
938945 SmallVector<COOSegment> coo = getCOOSegments ();
939946 assert (coo.size () == 1 || coo.empty ());
940947 if (!coo.empty () && coo.front ().isAoS ()) {
941948 return coo.front ().lvlRange .first ;
942949 }
943- return lvlRank ;
950+ return getLvlRank () ;
944951}
945952
946953SmallVector<COOSegment>
947- mlir::sparse_tensor::SparseTensorType ::getCOOSegments () const {
954+ mlir::sparse_tensor::SparseTensorEncodingAttr ::getCOOSegments () const {
948955 SmallVector<COOSegment> ret;
949- if (! hasEncoding () || lvlRank <= 1 )
956+ if (getLvlRank () <= 1 )
950957 return ret;
951958
952959 ArrayRef<LevelType> lts = getLvlTypes ();
953960 Level l = 0 ;
954- while (l < lvlRank ) {
961+ while (l < getLvlRank () ) {
955962 auto lt = lts[l];
956963 if (lt.isa <LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
957964 auto cur = lts.begin () + l;
@@ -975,6 +982,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
975982 return ret;
976983}
977984
985+ // ===----------------------------------------------------------------------===//
986+ // SparseTensorType Methods.
987+ // ===----------------------------------------------------------------------===//
988+
989+ bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
990+ bool isUnique) const {
991+ if (!hasEncoding ())
992+ return false ;
993+ if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
994+ return false ;
995+ for (Level l = startLvl + 1 ; l < lvlRank; ++l)
996+ if (!isSingletonLvl (l))
997+ return false ;
998+ // If isUnique is true, then make sure that the last level is unique,
999+ // that is, when lvlRank == 1, the only compressed level is unique,
1000+ // and when lvlRank > 1, the last singleton is unique.
1001+ return !isUnique || isUniqueLvl (lvlRank - 1 );
1002+ }
1003+
9781004RankedTensorType
9791005mlir::sparse_tensor::SparseTensorType::getCOOType (bool ordered) const {
9801006 SmallVector<LevelType> lvlTypes;
0 commit comments