Skip to content

Commit 23c9e8b

Browse files
committed
[mlir][tensors] Introduce attribute interface/attribute for tensor encoding
The new "encoding" field in tensor types so far had no meaning. This revision introduces: 1. an encoding attribute interface in IR: for verification between tensors and encodings in general 2. an attribute in Tensor dialect; #tensor.sparse<dict> + concrete sparse tensors API Active discussion: https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944/ Reviewed By: silvas, penpornk, bixia Differential Revision: https://reviews.llvm.org/D101008
1 parent bba7338 commit 23c9e8b

File tree

15 files changed

+417
-2
lines changed

15 files changed

+417
-2
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
add_mlir_dialect(TensorOps tensor)
22
add_mlir_doc(TensorOps TensorOps Dialects/ -gen-dialect-doc)
3+
4+
set(LLVM_TARGET_DEFINITIONS TensorAttrDefs.td)
5+
mlir_tablegen(TensorAttrDefs.h.inc -gen-attrdef-decls)
6+
mlir_tablegen(TensorAttrDefs.cpp.inc -gen-attrdef-defs)
7+
add_public_tablegen_target(MLIRTensorAttrDefsIncGen)

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpDefinition.h"
1515
#include "mlir/IR/OpImplementation.h"
16+
#include "mlir/IR/TensorEncoding.h"
1617
#include "mlir/Interfaces/CastInterfaces.h"
1718
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1819
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -23,6 +24,13 @@
2324

2425
#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
2526

27+
//===----------------------------------------------------------------------===//
28+
// Tensor Dialect Attributes
29+
//===----------------------------------------------------------------------===//
30+
31+
#define GET_ATTRDEF_CLASSES
32+
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.h.inc"
33+
2634
//===----------------------------------------------------------------------===//
2735
// Tensor Dialect Operations
2836
//===----------------------------------------------------------------------===//
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===-- TensorAttrDefs.td - Tensor Attributes Definitions --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef TENSOR_ATTRDEFS
10+
#define TENSOR_ATTRDEFS
11+
12+
include "mlir/Dialect/Tensor/IR/TensorBase.td"
13+
include "mlir/IR/TensorEncoding.td"
14+
15+
// All of the Tensor attributes will extend this class.
16+
class Tensor_Attr<string name,
17+
list<Trait> traits = []> : AttrDef<Tensor_Dialect, name, traits>;
18+
19+
// Sparse tensor encoding attribute.
20+
def SparseTensorEncodingAttr : Tensor_Attr<"SparseTensorEncoding",
21+
[ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
22+
let mnemonic = "sparse";
23+
24+
let description = [{
25+
An attribute to encode "TACO"-style information (see tensor-compiler.org)
26+
on the sparsity of tensors. The semantics are defined by means of the
27+
methods getDimLevelType(), getDimOrdering(), getPointerType(), and
28+
getIndexType(), documented below. The encoding is eventually used by
29+
a `sparse compiler` pass to generate sparse code fully automatically
30+
for all tensor expressions that involve tensors with a sparse encoding.
31+
Compiler passes that run before this sparse compiler pass need to be
32+
aware of the semantics of tensor types with such an encoding.
33+
}];
34+
35+
// All data is stored in a dictionary, interpreted by the methods below.
36+
let parameters = (
37+
ins
38+
"DictionaryAttr":$dict
39+
);
40+
41+
let extraClassDeclaration = [{
42+
// Dimension level types that define sparse tensors:
43+
// Dense - dimension is dense, every entry is stored
44+
// Compressed - dimension is sparse, only nonzeros are stored
45+
// Singleton - dimension contains single coordinate, no siblings
46+
enum class DimLevelType {
47+
Dense, Compressed, Singleton
48+
};
49+
50+
// Returns the dimension level type in the given dimension `dim`
51+
// of this tensor type. The choices, defined by the `DimLevelType`
52+
// enum, are `dense` (the dimension should be stored in its entirety),
53+
// `compressed` (only non-zero regions or elements should be stored),
54+
// or `singleton` (no sibling elements for parent).
55+
DimLevelType getDimLevelType(unsigned dim) const;
56+
57+
// Returns the dimension order of this tensor type as an AffineMap.
58+
// Unlike dense storage, most sparse storage schemes do not provide
59+
// fast random access. This affine map specifies the order of
60+
// dimensions that should be support by the sparse storage scheme
61+
// (e.g. (i,j) -> (i,j) requests 2-d row-wise and (i,j) -> (j,i)
62+
// requests 2-d column-wise storage).
63+
// TODO: block structure with higher-dim inputs
64+
AffineMap getDimOrdering() const;
65+
66+
// Returns the required bit width for pointer storage. A narrow width
67+
// reduces the memory footprint of overhead storage, as long as the
68+
// width suffices to define the total required range (viz. the maximum
69+
// number of stored entries over all indirection dimensions). The choices
70+
// are `8`, `16`, `32`, `64`, or `0` for a native width.
71+
unsigned getPointerBitWidth() const;
72+
73+
// Returns the required bit width for index storage. A narrow width
74+
// reduces the memory footprint of overhead storage, as long as the
75+
// width suffices to define the total required range (viz. the maximum
76+
// value of each tensor index over all dimensions). The choices are `8`,
77+
// `16`, `32`, `64`, or `0` for a native width.
78+
unsigned getIndexBitWidth() const;
79+
}];
80+
}
81+
82+
#endif // LLVMIR_ATTRDEFS

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define TENSOR_OPS
1111

1212
include "mlir/Dialect/Tensor/IR/TensorBase.td"
13+
include "mlir/Dialect/Tensor/IR/TensorAttrDefs.td"
1314
include "mlir/Interfaces/CastInterfaces.td"
1415
include "mlir/Interfaces/ControlFlowInterfaces.td"
1516
include "mlir/Interfaces/SideEffectInterfaces.td"

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
2626
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
2727
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
2828

29+
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
30+
mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
31+
mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
32+
add_public_tablegen_target(MLIRTensorEncodingIncGen)
33+
2934
add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
3035
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
3136
add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)

mlir/include/mlir/IR/TensorEncoding.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- TensorEncoding.h - MLIR Tensor Encoding Declarations------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_IR_TENSORENCODING_H
10+
#define MLIR_IR_TENSORENCODING_H
11+
12+
#include "mlir/IR/AffineMap.h"
13+
#include "mlir/IR/OpDefinition.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// Tablegen Type Declarations
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "mlir/IR/TensorEncInterfaces.h.inc"
20+
21+
#endif // MLIR_IR_TENSORENCODING_H
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- TensorEncoding.td - Tensor encoding interfaces ------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the interfaces associated with tensor encoding attributes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_IR_TENSORINTERFACES
14+
#define MLIR_IR_TENSORINTERFACES
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
//===----------------------------------------------------------------------===//
19+
// Attribute interface to verify a tensor encoding.
20+
//===----------------------------------------------------------------------===//
21+
22+
def VerifiableTensorEncoding : AttrInterface<"VerifiableTensorEncoding"> {
23+
let cppNamespace = "::mlir";
24+
let description = [{
25+
Verifies an encoding attribute for a tensor.
26+
}];
27+
let methods = [
28+
InterfaceMethod<
29+
/*desc=*/[{
30+
Verifies the encoding is valid for a tensor type with the
31+
given shape and element type. Generates a diagnostic using
32+
the supplied callback on failure.
33+
}],
34+
/*retTy=*/"::mlir::LogicalResult",
35+
/*methodName=*/"verifyEncoding",
36+
/*args=*/(ins
37+
"ArrayRef<int64_t>":$shape,
38+
"Type":$elementType,
39+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
40+
>,
41+
];
42+
}
43+
44+
#endif // MLIR_IR_TENSORINTERFACES

mlir/lib/Dialect/Tensor/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTensor
77

88
DEPENDS
99
MLIRTensorOpsIncGen
10+
MLIRTensorAttrDefsIncGen
1011

1112
LINK_COMPONENTS
1213
Core

mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,142 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Tensor/IR/Tensor.h"
10+
#include "mlir/IR/DialectImplementation.h"
1011
#include "mlir/Transforms/InliningUtils.h"
12+
#include "llvm/ADT/TypeSwitch.h"
1113

1214
using namespace mlir;
1315
using namespace mlir::tensor;
1416

17+
//===----------------------------------------------------------------------===//
18+
// TableGen'd Attributes Methods
19+
//===----------------------------------------------------------------------===//
20+
21+
#define GET_ATTRDEF_CLASSES
22+
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
23+
24+
// Dictionary keys.
25+
static constexpr StringRef getSparseDimLevelTypeAttrName() {
26+
return "sparseDimLevelType";
27+
}
28+
static constexpr StringRef getSparseDimOrderingAttrName() {
29+
return "sparseDimOrdering";
30+
}
31+
static constexpr StringRef getSparsePointerBitWidthAttrName() {
32+
return "sparsePointerBitWidth";
33+
}
34+
static constexpr StringRef getSparseIndexBitWidthAttrName() {
35+
return "sparseIndexBitWidth";
36+
}
37+
38+
// Dictionary values.
39+
static constexpr StringRef getDenseDimLevelTypeVal() { return "dense"; }
40+
static constexpr StringRef getCompressedDimLevelTypeVal() {
41+
return "compressed";
42+
}
43+
static constexpr StringRef getSingletonDimLevelTypeVal() { return "singleton"; }
44+
45+
Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
46+
DialectAsmParser &parser, Type type) {
47+
if (failed(parser.parseLess()))
48+
return {};
49+
DictionaryAttr dict;
50+
if (failed(parser.parseAttribute(dict)))
51+
return {};
52+
if (failed(parser.parseGreater()))
53+
return {};
54+
return SparseTensorEncodingAttr::get(context, dict);
55+
}
56+
57+
void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
58+
printer << "sparse<" << getDict() << ">";
59+
}
60+
61+
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
62+
llvm::ArrayRef<int64_t> shape, Type elementType,
63+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
64+
unsigned size = shape.size();
65+
for (const NamedAttribute &attr : getDict()) {
66+
if (attr.first == getSparseDimLevelTypeAttrName()) {
67+
// Dimension level type verification.
68+
auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
69+
if (!arrayAttr || size != static_cast<int64_t>(arrayAttr.size()))
70+
return emitError() << "expected an array of size " << size
71+
<< " for dimension level types";
72+
for (unsigned i = 0; i < size; i++) {
73+
auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
74+
if (!strAttr)
75+
return emitError()
76+
<< "expected string value in dimension level types";
77+
auto strVal = strAttr.getValue();
78+
if (strVal != getDenseDimLevelTypeVal() &&
79+
strVal != getCompressedDimLevelTypeVal() &&
80+
strVal != getSingletonDimLevelTypeVal())
81+
return emitError() << "unexpected dimension level type: " << strAttr;
82+
}
83+
} else if (attr.first == getSparseDimOrderingAttrName()) {
84+
// Dimension order verification.
85+
auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
86+
if (!affineAttr)
87+
return emitError() << "expected an affine map for dimension ordering";
88+
AffineMap map = affineAttr.getValue();
89+
if (size != map.getNumResults() || !map.isPermutation())
90+
return emitError() << "expected a permutation affine map of size "
91+
<< size << " for dimension ordering";
92+
} else if (attr.first == getSparsePointerBitWidthAttrName() ||
93+
attr.first == getSparseIndexBitWidthAttrName()) {
94+
// Pointer or index bitwidth verification.
95+
auto intAttr = attr.second.dyn_cast<IntegerAttr>();
96+
if (!intAttr)
97+
return emitError() << "expected an integral bitwidth";
98+
switch (intAttr.getInt()) {
99+
case 0:
100+
case 8:
101+
case 16:
102+
case 32:
103+
case 64:
104+
continue;
105+
default:
106+
return emitError() << "unexpected bitwidth: " << intAttr.getInt();
107+
}
108+
} else {
109+
return emitError() << "unexpected key: " << attr.first.str();
110+
}
111+
}
112+
return success();
113+
}
114+
115+
SparseTensorEncodingAttr::DimLevelType
116+
SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const {
117+
if (auto value = getDict().get(getSparseDimLevelTypeAttrName())) {
118+
auto strVal =
119+
value.dyn_cast<ArrayAttr>()[dim].cast<StringAttr>().getValue();
120+
if (strVal == getCompressedDimLevelTypeVal())
121+
return DimLevelType::Compressed;
122+
if (strVal == getSingletonDimLevelTypeVal())
123+
return DimLevelType::Singleton;
124+
}
125+
return DimLevelType::Dense;
126+
}
127+
128+
AffineMap SparseTensorEncodingAttr::getDimOrdering() const {
129+
if (auto value = getDict().get(getSparseDimOrderingAttrName()))
130+
return value.cast<AffineMapAttr>().getValue();
131+
return {};
132+
}
133+
134+
unsigned SparseTensorEncodingAttr::getPointerBitWidth() const {
135+
if (auto value = getDict().get(getSparsePointerBitWidthAttrName()))
136+
return value.cast<IntegerAttr>().getInt();
137+
return 0;
138+
}
139+
140+
unsigned SparseTensorEncodingAttr::getIndexBitWidth() const {
141+
if (auto value = getDict().get(getSparseIndexBitWidthAttrName()))
142+
return value.cast<IntegerAttr>().getInt();
143+
return 0;
144+
}
145+
15146
//===----------------------------------------------------------------------===//
16147
// TensorDialect Dialect Interfaces
17148
//===----------------------------------------------------------------------===//
@@ -30,10 +161,38 @@ struct TensorInlinerInterface : public DialectInlinerInterface {
30161
};
31162
} // end anonymous namespace
32163

164+
//===----------------------------------------------------------------------===//
165+
// TensorDialect Methods
166+
//===----------------------------------------------------------------------===//
167+
33168
void TensorDialect::initialize() {
169+
addAttributes<
170+
#define GET_ATTRDEF_LIST
171+
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
172+
>();
34173
addOperations<
35174
#define GET_OP_LIST
36175
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
37176
>();
38177
addInterfaces<TensorInlinerInterface>();
39178
}
179+
180+
Attribute TensorDialect::parseAttribute(DialectAsmParser &parser,
181+
Type type) const {
182+
StringRef attrTag;
183+
if (failed(parser.parseKeyword(&attrTag)))
184+
return Attribute();
185+
Attribute attr;
186+
auto parseResult =
187+
generatedAttributeParser(getContext(), parser, attrTag, type, attr);
188+
if (parseResult.hasValue())
189+
return attr;
190+
parser.emitError(parser.getNameLoc(), "unknown tensor attribute");
191+
return Attribute();
192+
}
193+
194+
void TensorDialect::printAttribute(::mlir::Attribute attr,
195+
::mlir::DialectAsmPrinter &printer) const {
196+
if (succeeded(generatedAttributePrinter(attr, printer)))
197+
return;
198+
}

0 commit comments

Comments
 (0)