Skip to content

Commit 6c8a4f3

Browse files
ftynsememfrob
authored andcommitted
[mlir] speed up construction of LLVM IR constants when possible
The translation to LLVM IR used to construct sequential constants by recurring down to individual elements, creating constant values for them, and wrapping them into aggregate constants in post-order. This is highly inefficient for large constants with known data such as DenseElementsAttr. Use LLVM's ConstantData for the innermost dimension instead. LLVM does seem to support data constants for nested sequential constants so the outer dimensions are still handled recursively. Nevertheless, this speeds up the translation of large constants with equal dimensions by up to 30x. Users are advised to rewrite large constants to use flat types before translating to LLVM IR if more efficiency in translation is necessary. This is not done automatically as the translation is not aware of the expectations of the overall compilation flow about type changes and indexing, in particular for global constants with external linkage. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D109152
1 parent b4e989a commit 6c8a4f3

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,92 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
101101
} while (true);
102102
}
103103

104+
/// Convert a dense elements attribute to an LLVM IR constant using its raw data
105+
/// storage if possible. This supports elements attributes of tensor or vector
106+
/// type and avoids constructing separate objects for individual values of the
107+
/// innermost dimension. Constants for other dimensions are still constructed
108+
/// recursively. Returns null if constructing from raw data is not supported for
109+
/// this type, e.g., element type is not a power-of-two-sized primitive. Reports
110+
/// other errors at `loc`.
111+
static llvm::Constant *
112+
convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
113+
llvm::Type *llvmType,
114+
const ModuleTranslation &moduleTranslation) {
115+
if (!denseElementsAttr)
116+
return nullptr;
117+
118+
llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
119+
if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
120+
return nullptr;
121+
122+
// Compute the shape of all dimensions but the innermost. Note that the
123+
// innermost dimension may be that of the vector element type.
124+
ShapedType type = denseElementsAttr.getType();
125+
bool hasVectorElementType = type.getElementType().isa<VectorType>();
126+
unsigned numAggregates =
127+
denseElementsAttr.getNumElements() /
128+
(hasVectorElementType ? 1
129+
: denseElementsAttr.getType().getShape().back());
130+
ArrayRef<int64_t> outerShape = type.getShape();
131+
if (!hasVectorElementType)
132+
outerShape = outerShape.drop_back();
133+
134+
// Handle the case of vector splat, LLVM has special support for it.
135+
if (denseElementsAttr.isSplat() &&
136+
(type.isa<VectorType>() || hasVectorElementType)) {
137+
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
138+
innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
139+
moduleTranslation, /*isTopLevel=*/false);
140+
llvm::Constant *splatVector =
141+
llvm::ConstantDataVector::getSplat(0, splatValue);
142+
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
143+
ArrayRef<llvm::Constant *> constantsRef = constants;
144+
return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
145+
}
146+
if (denseElementsAttr.isSplat())
147+
return nullptr;
148+
149+
// In case of non-splat, create a constructor for the innermost constant from
150+
// a piece of raw data.
151+
std::function<llvm::Constant *(StringRef)> buildCstData;
152+
if (type.isa<TensorType>()) {
153+
auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
154+
if (vectorElementType && vectorElementType.getRank() == 1) {
155+
buildCstData = [&](StringRef data) {
156+
return llvm::ConstantDataVector::getRaw(
157+
data, vectorElementType.getShape().back(), innermostLLVMType);
158+
};
159+
} else if (!vectorElementType) {
160+
buildCstData = [&](StringRef data) {
161+
return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
162+
innermostLLVMType);
163+
};
164+
}
165+
} else if (type.isa<VectorType>()) {
166+
buildCstData = [&](StringRef data) {
167+
return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
168+
innermostLLVMType);
169+
};
170+
}
171+
if (!buildCstData)
172+
return nullptr;
173+
174+
// Create innermost constants and defer to the default constant creation
175+
// mechanism for other dimensions.
176+
SmallVector<llvm::Constant *> constants;
177+
unsigned aggregateSize = denseElementsAttr.getType().getShape().back() *
178+
(innermostLLVMType->getScalarSizeInBits() / 8);
179+
constants.reserve(numAggregates);
180+
for (unsigned i = 0; i < numAggregates; ++i) {
181+
StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
182+
aggregateSize);
183+
constants.push_back(buildCstData(data));
184+
}
185+
186+
ArrayRef<llvm::Constant *> constantsRef = constants;
187+
return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
188+
}
189+
104190
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
105191
/// This currently supports integer, floating point, splat and dense element
106192
/// attributes and combinations thereof. Also, an array attribute with two
@@ -178,6 +264,14 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
178264
}
179265
}
180266

267+
// Try using raw elements data if possible.
268+
if (llvm::Constant *result =
269+
convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
270+
llvmType, moduleTranslation)) {
271+
return result;
272+
}
273+
274+
// Fall back to element-by-element construction otherwise.
181275
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
182276
assert(elementsAttr.getType().hasStaticShape());
183277
assert(!elementsAttr.getType().getShape().empty() &&

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ llvm.mlir.global internal constant @int_gep() : !llvm.ptr<i32> {
5050
llvm.return %gepinit : !llvm.ptr<i32>
5151
}
5252

53+
// CHECK{LITERAL}: @dense_float_vector = internal global <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>
54+
llvm.mlir.global internal @dense_float_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf32>) : vector<3xf32>
55+
56+
// CHECK{LITERAL}: @splat_float_vector = internal global <3 x float> <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
57+
llvm.mlir.global internal @splat_float_vector(dense<42.0> : vector<3xf32>) : vector<3xf32>
58+
59+
// CHECK{LITERAL}: @dense_double_vector = internal global <3 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00>
60+
llvm.mlir.global internal @dense_double_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf64>) : vector<3xf64>
61+
62+
// CHECK{LITERAL}: @splat_double_vector = internal global <3 x double> <double 4.200000e+01, double 4.200000e+01, double 4.200000e+01>
63+
llvm.mlir.global internal @splat_double_vector(dense<42.0> : vector<3xf64>) : vector<3xf64>
64+
65+
// CHECK{LITERAL}: @dense_i64_vector = internal global <3 x i64> <i64 1, i64 2, i64 3>
66+
llvm.mlir.global internal @dense_i64_vector(dense<[1, 2, 3]> : vector<3xi64>) : vector<3xi64>
67+
68+
// CHECK{LITERAL}: @splat_i64_vector = internal global <3 x i64> <i64 42, i64 42, i64 42>
69+
llvm.mlir.global internal @splat_i64_vector(dense<42> : vector<3xi64>) : vector<3xi64>
70+
71+
// CHECK{LITERAL}: @dense_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>]
72+
llvm.mlir.global internal @dense_float_vector_2d(dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
73+
74+
// CHECK{LITERAL}: @splat_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]
75+
llvm.mlir.global internal @splat_float_vector_2d(dense<42.0> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
76+
77+
// CHECK{LITERAL}: @dense_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>], [2 x <2 x float>] [<2 x float> <float 5.000000e+00, float 6.000000e+00>, <2 x float> <float 7.000000e+00, float 8.000000e+00>]]
78+
llvm.mlir.global internal @dense_float_vector_3d(dense<[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
79+
80+
// CHECK{LITERAL}: @splat_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>], [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]]
81+
llvm.mlir.global internal @splat_float_vector_3d(dense<42.0> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
82+
5383
//
5484
// Linkage attribute.
5585
//
@@ -67,7 +97,7 @@ llvm.mlir.global weak @weak(42 : i32) : i32
6797
// CHECK: @common = common global i32 0
6898
llvm.mlir.global common @common(0 : i32) : i32
6999
// CHECK: @appending = appending global [3 x i32] [i32 1, i32 2, i32 3]
70-
llvm.mlir.global appending @appending(dense<[1,2,3]> : vector<3xi32>) : !llvm.array<3xi32>
100+
llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3xi32>
71101
// CHECK: @extern_weak = extern_weak global i32
72102
llvm.mlir.global extern_weak @extern_weak() : i32
73103
// CHECK: @linkonce_odr = linkonce_odr global i32 42

0 commit comments

Comments
 (0)