Skip to content

Commit 1eb5b18

Browse files
authored
[mlir][emitc] Support dense as init value for ShapedType (llvm#144826)
1 parent 19ebfa6 commit 1eb5b18

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,7 +1447,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
14471447
}
14481448
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
14491449
if (auto iType = dyn_cast<IntegerType>(
1450-
cast<TensorType>(dense.getType()).getElementType())) {
1450+
cast<ShapedType>(dense.getType()).getElementType())) {
14511451
os << '{';
14521452
interleaveComma(dense, os, [&](const APInt &val) {
14531453
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -1456,7 +1456,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
14561456
return success();
14571457
}
14581458
if (auto iType = dyn_cast<IndexType>(
1459-
cast<TensorType>(dense.getType()).getElementType())) {
1459+
cast<ShapedType>(dense.getType()).getElementType())) {
14601460
os << '{';
14611461
interleaveComma(dense, os,
14621462
[&](const APInt &val) { printInt(val, false); });

mlir/test/Target/Cpp/const.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ func.func @emitc_constant() {
1616
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
1717
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
1818
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
19+
%c11 = "emitc.constant"(){value = dense<[0, 1]> : !emitc.array<2xindex>} : () -> !emitc.array<2xindex>
20+
%c12 = "emitc.constant"(){value = dense<[0.0, 1.0]> : !emitc.array<2xf32>} : () -> !emitc.array<2xf32>
1921
return
2022
}
2123
// CPP-DEFAULT: void emitc_constant() {
@@ -33,6 +35,8 @@ func.func @emitc_constant() {
3335
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
3436
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
3537
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
38+
// CPP-DEFAULT-NEXT: size_t [[V11:[^ ]*]][2] = {0, 1};
39+
// CPP-DEFAULT-NEXT: float [[V12:[^ ]*]][2] = {0.0e+00f, 1.000000000e+00f};
3640

3741
// CPP-DECLTOP: void emitc_constant() {
3842
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -49,6 +53,8 @@ func.func @emitc_constant() {
4953
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
5054
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
5155
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
56+
// CPP-DECLTOP-NEXT: size_t [[V11:[^ ]*]][2];
57+
// CPP-DECLTOP-NEXT: float [[V12:[^ ]*]][2];
5258
// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX;
5359
// CPP-DECLTOP-NEXT: [[V1]] = 42;
5460
// CPP-DECLTOP-NEXT: [[V2]] = -1;
@@ -63,3 +69,5 @@ func.func @emitc_constant() {
6369
// CPP-DECLTOP-NEXT: [[V8]] = {0};
6470
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
6571
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
72+
// CPP-DECLTOP-NEXT: [[V11]] = {0, 1};
73+
// CPP-DECLTOP-NEXT: [[V12]] = {0.0e+00f, 1.000000000e+00f};

0 commit comments

Comments
 (0)