Skip to content

Commit 5aeb590

Browse files
author
MengmengSun
committed
Fix element type of target attributes in oneToOneRewrite when converting to llvm
1 parent eddd342 commit 5aeb590

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,35 @@ LogicalResult LLVM::detail::oneToOneRewrite(
330330
return failure();
331331
}
332332

333+
// If the targetAttrs contains DenseElementsAttr,
334+
// and the element type of the DenseElementsAttr and result type is
335+
// inconsistent after the conversion of result types, we need to convert the
336+
// element type of the DenseElementsAttr to the target type by creating a new
337+
// DenseElementsAttr with the converted element type, and use the new
338+
// DenseElementsAttr to replace the old one in the targetAttrs
339+
SmallVector<NamedAttribute> convertedAttrs;
340+
for (auto attr : targetAttrs) {
341+
if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue())) {
342+
VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType());
343+
if (vectorType) {
344+
auto convertedElementType =
345+
typeConverter.convertType(vectorType.getElementType());
346+
VectorType convertedVectorType =
347+
VectorType::get(vectorType.getShape(), convertedElementType,
348+
vectorType.getScalableDims());
349+
convertedAttrs.emplace_back(
350+
attr.getName(), DenseElementsAttr::getFromRawBuffer(
351+
convertedVectorType, denseAttr.getRawData()));
352+
}
353+
} else {
354+
convertedAttrs.push_back(attr);
355+
}
356+
}
357+
333358
// Create the operation through state since we don't know its C++ type.
334359
Operation *newOp =
335360
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
336-
resultTypes, targetAttrs);
361+
resultTypes, convertedAttrs);
337362

338363
setNativeProperties(newOp, overflowFlags);
339364

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ func.func @fcmp(f32, f32) -> () {
428428

429429
// CHECK-LABEL: @index_vector
430430
func.func @index_vector(%arg0: vector<4xindex>) {
431-
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
431+
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xi64>) : vector<4xi64>
432432
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
433433
// CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
434434
%1 = arith.addi %arg0, %0 : vector<4xindex>
@@ -437,6 +437,21 @@ func.func @index_vector(%arg0: vector<4xindex>) {
437437

438438
// -----
439439

440+
// CHECK-LABEL: @f8E4M3FN_vector
441+
func.func @f8E4M3FN_vector() -> vector<4xf8E4M3FN> {
442+
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
443+
%0 = arith.constant dense<0.000000e+00> : vector<4xf8E4M3FN>
444+
// CHECK: %[[CST1:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
445+
%1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf8E4M3FN>
446+
// CHECK: %[[V:.*]] = llvm.mlir.constant(dense<[56, 64, 68, 72]> : vector<4xi8>) : vector<4xi8>
447+
%2 = arith.addf %0, %1 : vector<4xf8E4M3FN>
448+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V]] : vector<4xi8> to vector<4xf8E4M3FN>
449+
// CHECK-NEXT: return %[[RES]] : vector<4xf8E4M3FN>
450+
func.return %2 : vector<4xf8E4M3FN>
451+
}
452+
453+
// -----
454+
440455
// CHECK-LABEL: @bitcast_1d
441456
func.func @bitcast_1d(%arg0: vector<2xf32>) {
442457
// CHECK: llvm.bitcast %{{.*}} : vector<2xf32> to vector<2xi32>

0 commit comments

Comments
 (0)