7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Tensor/IR/Tensor.h"
10
+ #include " mlir/IR/DialectImplementation.h"
10
11
#include " mlir/Transforms/InliningUtils.h"
12
+ #include " llvm/ADT/TypeSwitch.h"
11
13
12
14
using namespace mlir ;
13
15
using namespace mlir ::tensor;
14
16
17
+ // ===----------------------------------------------------------------------===//
18
+ // TableGen'd Attributes Methods
19
+ // ===----------------------------------------------------------------------===//
20
+
21
+ #define GET_ATTRDEF_CLASSES
22
+ #include " mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
23
+
24
+ // Dictionary keys.
25
+ static constexpr StringRef getSparseDimLevelTypeAttrName () {
26
+ return " sparseDimLevelType" ;
27
+ }
28
+ static constexpr StringRef getSparseDimOrderingAttrName () {
29
+ return " sparseDimOrdering" ;
30
+ }
31
+ static constexpr StringRef getSparsePointerBitWidthAttrName () {
32
+ return " sparsePointerBitWidth" ;
33
+ }
34
+ static constexpr StringRef getSparseIndexBitWidthAttrName () {
35
+ return " sparseIndexBitWidth" ;
36
+ }
37
+
38
+ // Dictionary values.
39
+ static constexpr StringRef getDenseDimLevelTypeVal () { return " dense" ; }
40
+ static constexpr StringRef getCompressedDimLevelTypeVal () {
41
+ return " compressed" ;
42
+ }
43
+ static constexpr StringRef getSingletonDimLevelTypeVal () { return " singleton" ; }
44
+
45
+ Attribute SparseTensorEncodingAttr::parse (MLIRContext *context,
46
+ DialectAsmParser &parser, Type type) {
47
+ if (failed (parser.parseLess ()))
48
+ return {};
49
+ DictionaryAttr dict;
50
+ if (failed (parser.parseAttribute (dict)))
51
+ return {};
52
+ if (failed (parser.parseGreater ()))
53
+ return {};
54
+ return SparseTensorEncodingAttr::get (context, dict);
55
+ }
56
+
57
+ void SparseTensorEncodingAttr::print (DialectAsmPrinter &printer) const {
58
+ printer << " sparse<" << getDict () << " >" ;
59
+ }
60
+
61
+ LogicalResult SparseTensorEncodingAttr::verifyEncoding (
62
+ llvm::ArrayRef<int64_t > shape, Type elementType,
63
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
64
+ unsigned size = shape.size ();
65
+ for (const NamedAttribute &attr : getDict ()) {
66
+ if (attr.first == getSparseDimLevelTypeAttrName ()) {
67
+ // Dimension level type verification.
68
+ auto arrayAttr = attr.second .dyn_cast <ArrayAttr>();
69
+ if (!arrayAttr || size != static_cast <int64_t >(arrayAttr.size ()))
70
+ return emitError () << " expected an array of size " << size
71
+ << " for dimension level types" ;
72
+ for (unsigned i = 0 ; i < size; i++) {
73
+ auto strAttr = arrayAttr[i].dyn_cast <StringAttr>();
74
+ if (!strAttr)
75
+ return emitError ()
76
+ << " expected string value in dimension level types" ;
77
+ auto strVal = strAttr.getValue ();
78
+ if (strVal != getDenseDimLevelTypeVal () &&
79
+ strVal != getCompressedDimLevelTypeVal () &&
80
+ strVal != getSingletonDimLevelTypeVal ())
81
+ return emitError () << " unexpected dimension level type: " << strAttr;
82
+ }
83
+ } else if (attr.first == getSparseDimOrderingAttrName ()) {
84
+ // Dimension order verification.
85
+ auto affineAttr = attr.second .dyn_cast <AffineMapAttr>();
86
+ if (!affineAttr)
87
+ return emitError () << " expected an affine map for dimension ordering" ;
88
+ AffineMap map = affineAttr.getValue ();
89
+ if (size != map.getNumResults () || !map.isPermutation ())
90
+ return emitError () << " expected a permutation affine map of size "
91
+ << size << " for dimension ordering" ;
92
+ } else if (attr.first == getSparsePointerBitWidthAttrName () ||
93
+ attr.first == getSparseIndexBitWidthAttrName ()) {
94
+ // Pointer or index bitwidth verification.
95
+ auto intAttr = attr.second .dyn_cast <IntegerAttr>();
96
+ if (!intAttr)
97
+ return emitError () << " expected an integral bitwidth" ;
98
+ switch (intAttr.getInt ()) {
99
+ case 0 :
100
+ case 8 :
101
+ case 16 :
102
+ case 32 :
103
+ case 64 :
104
+ continue ;
105
+ default :
106
+ return emitError () << " unexpected bitwidth: " << intAttr.getInt ();
107
+ }
108
+ } else {
109
+ return emitError () << " unexpected key: " << attr.first .str ();
110
+ }
111
+ }
112
+ return success ();
113
+ }
114
+
115
+ SparseTensorEncodingAttr::DimLevelType
116
+ SparseTensorEncodingAttr::getDimLevelType (unsigned dim) const {
117
+ if (auto value = getDict ().get (getSparseDimLevelTypeAttrName ())) {
118
+ auto strVal =
119
+ value.dyn_cast <ArrayAttr>()[dim].cast <StringAttr>().getValue ();
120
+ if (strVal == getCompressedDimLevelTypeVal ())
121
+ return DimLevelType::Compressed;
122
+ if (strVal == getSingletonDimLevelTypeVal ())
123
+ return DimLevelType::Singleton;
124
+ }
125
+ return DimLevelType::Dense;
126
+ }
127
+
128
+ AffineMap SparseTensorEncodingAttr::getDimOrdering () const {
129
+ if (auto value = getDict ().get (getSparseDimOrderingAttrName ()))
130
+ return value.cast <AffineMapAttr>().getValue ();
131
+ return {};
132
+ }
133
+
134
+ unsigned SparseTensorEncodingAttr::getPointerBitWidth () const {
135
+ if (auto value = getDict ().get (getSparsePointerBitWidthAttrName ()))
136
+ return value.cast <IntegerAttr>().getInt ();
137
+ return 0 ;
138
+ }
139
+
140
+ unsigned SparseTensorEncodingAttr::getIndexBitWidth () const {
141
+ if (auto value = getDict ().get (getSparseIndexBitWidthAttrName ()))
142
+ return value.cast <IntegerAttr>().getInt ();
143
+ return 0 ;
144
+ }
145
+
15
146
// ===----------------------------------------------------------------------===//
16
147
// TensorDialect Dialect Interfaces
17
148
// ===----------------------------------------------------------------------===//
@@ -30,10 +161,38 @@ struct TensorInlinerInterface : public DialectInlinerInterface {
30
161
};
31
162
} // end anonymous namespace
32
163
164
+ // ===----------------------------------------------------------------------===//
165
+ // TensorDialect Methods
166
+ // ===----------------------------------------------------------------------===//
167
+
33
168
void TensorDialect::initialize () {
169
+ addAttributes<
170
+ #define GET_ATTRDEF_LIST
171
+ #include " mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
172
+ >();
34
173
addOperations<
35
174
#define GET_OP_LIST
36
175
#include " mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
37
176
>();
38
177
addInterfaces<TensorInlinerInterface>();
39
178
}
179
+
180
+ Attribute TensorDialect::parseAttribute (DialectAsmParser &parser,
181
+ Type type) const {
182
+ StringRef attrTag;
183
+ if (failed (parser.parseKeyword (&attrTag)))
184
+ return Attribute ();
185
+ Attribute attr;
186
+ auto parseResult =
187
+ generatedAttributeParser (getContext (), parser, attrTag, type, attr);
188
+ if (parseResult.hasValue ())
189
+ return attr;
190
+ parser.emitError (parser.getNameLoc (), " unknown tensor attribute" );
191
+ return Attribute ();
192
+ }
193
+
194
+ void TensorDialect::printAttribute (::mlir::Attribute attr,
195
+ ::mlir::DialectAsmPrinter &printer) const {
196
+ if (succeeded (generatedAttributePrinter (attr, printer)))
197
+ return ;
198
+ }
0 commit comments