Skip to content

Commit 376310e

Browse files
felixdaaslanza
authored andcommitted
[ThroughMLIR] basic printf support (llvm#1687)
This PR is related to llvm#1685 and adds some basic support for the printf function. Limitations: 1. It only works if all variadic params are of basic integer/float type (for more info on why memref type operands don't work, see llvm#1685). 2. It only works if the format string is defined directly inside the printf call. The downside of this PR is that handling this edge case adds significant code bloat and reduces readability of the cir.call op lowering (I tried to insert some meaningful comments to improve the readability), but I think it's worth doing so that we have some basic printf support (without adding an extra cir operation) until upstream support for variadic functions is added to the func dialect. Also, a few more tests (which use such a basic form of printf) in the LLVM Single Source test suite are working with this PR. Before this PR: Testing Time: 4.00s Total Discovered Tests: 1833 Passed : 420 (22.91%) Failed : 10 (0.55%) Executable Missing: 1403 (76.54%). With this PR: Testing Time: 10.29s Total Discovered Tests: 1833 Passed : 458 (24.99%) Failed : 6 (0.33%) Executable Missing: 1369 (74.69%)
1 parent 1c9d016 commit 376310e

File tree

2 files changed

+206
-34
lines changed

2 files changed

+206
-34
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 168 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2424
#include "mlir/Dialect/Func/IR/FuncOps.h"
2525
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2627
#include "mlir/Dialect/Math/IR/Math.h"
2728
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2829
#include "mlir/Dialect/SCF/IR/SCF.h"
2930
#include "mlir/Dialect/Vector/IR/VectorOps.h"
31+
#include "mlir/IR/BuiltinAttributes.h"
3032
#include "mlir/IR/BuiltinDialect.h"
3133
#include "mlir/IR/BuiltinOps.h"
3234
#include "mlir/IR/BuiltinTypes.h"
@@ -109,9 +111,124 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
109111
if (mlir::failed(
110112
getTypeConverter()->convertTypes(op.getResultTypes(), types)))
111113
return mlir::failure();
112-
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
113-
op, op.getCalleeAttr(), types, adaptor.getOperands());
114-
return mlir::LogicalResult::success();
114+
115+
if (!op.isIndirect()) {
116+
// Currently variadic functions are not supported by the builtin func
117+
// dialect. For now only basic calls to printf are supported by using the
118+
// llvmir dialect.
119+
// TODO: remove this and add support for variadic function calls once
120+
// TODO: supported by the func dialect
121+
if (op.getCallee()->equals_insensitive("printf")) {
122+
SmallVector<mlir::Type> operandTypes =
123+
llvm::to_vector(adaptor.getOperands().getTypes());
124+
125+
// Drop the initial memref operand type (we replace the memref format
126+
// string with equivalent llvm.mlir ops)
127+
operandTypes.erase(operandTypes.begin());
128+
129+
// Check that the printf attributes can be used in llvmir dialect (i.e
130+
// they have integer/float type)
131+
if (!llvm::all_of(operandTypes, [](mlir::Type ty) {
132+
return mlir::LLVM::isCompatibleType(ty);
133+
})) {
134+
return op.emitError()
135+
<< "lowering of printf attributes having a type that is "
136+
"converted to memref in cir-to-mlir lowering (e.g. "
137+
"pointers) not supported yet";
138+
}
139+
140+
// Currently only versions of printf are supported where the format
141+
// string is defined inside the printf ==> the lowering of the cir ops
142+
// will match:
143+
// %global = memref.get_global %frm_str
144+
// %* = memref.reinterpret_cast (%global, 0)
145+
if (auto reinterpret_castOP =
146+
mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
147+
adaptor.getOperands()[0].getDefiningOp())) {
148+
if (auto getGlobalOp =
149+
mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
150+
reinterpret_castOP->getOperand(0).getDefiningOp())) {
151+
mlir::ModuleOp parentModule = op->getParentOfType<mlir::ModuleOp>();
152+
153+
auto context = rewriter.getContext();
154+
155+
// Find the memref.global op defining the frm_str
156+
auto globalOp = parentModule.lookupSymbol<mlir::memref::GlobalOp>(
157+
getGlobalOp.getNameAttr());
158+
159+
rewriter.setInsertionPoint(globalOp);
160+
161+
// Insert an equivalent llvm.mlir.global
162+
auto initialvalueAttr =
163+
mlir::dyn_cast_or_null<mlir::DenseIntElementsAttr>(
164+
globalOp.getInitialValueAttr());
165+
166+
auto type = mlir::LLVM::LLVMArrayType::get(
167+
mlir::IntegerType::get(context, 8),
168+
initialvalueAttr.getNumElements());
169+
170+
auto llvmglobalOp = rewriter.create<mlir::LLVM::GlobalOp>(
171+
globalOp->getLoc(), type, true, mlir::LLVM::Linkage::Internal,
172+
"printf_format_" + globalOp.getSymName().str(),
173+
initialvalueAttr, 0);
174+
175+
rewriter.setInsertionPoint(getGlobalOp);
176+
177+
// Insert llvmir dialect ops to retrieve the !llvm.ptr of the global
178+
auto globalPtrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
179+
getGlobalOp->getLoc(), llvmglobalOp);
180+
181+
mlir::Value cst0 = rewriter.create<mlir::LLVM::ConstantOp>(
182+
getGlobalOp->getLoc(), rewriter.getI8Type(),
183+
rewriter.getIndexAttr(0));
184+
auto gepPtrOp = rewriter.create<mlir::LLVM::GEPOp>(
185+
getGlobalOp->getLoc(),
186+
mlir::LLVM::LLVMPointerType::get(context),
187+
llvmglobalOp.getType(), globalPtrOp,
188+
ArrayRef<mlir::Value>({cst0, cst0}));
189+
190+
mlir::ValueRange operands = adaptor.getOperands();
191+
192+
// Replace the old memref operand with the !llvm.ptr for the frm_str
193+
mlir::SmallVector<mlir::Value> newOperands;
194+
newOperands.push_back(gepPtrOp);
195+
newOperands.append(operands.begin() + 1, operands.end());
196+
197+
// Create the llvmir dialect function type for printf
198+
auto llvmI32Ty = mlir::IntegerType::get(context, 32);
199+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context);
200+
auto llvmFnType =
201+
mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
202+
/*isVarArg=*/true);
203+
204+
rewriter.setInsertionPoint(op);
205+
206+
// Insert an llvm.call op with the updated operands to printf
207+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
208+
op, llvmFnType, op.getCalleeAttr(), newOperands);
209+
210+
// Cleanup printf frm_str memref ops
211+
rewriter.eraseOp(reinterpret_castOP);
212+
rewriter.eraseOp(getGlobalOp);
213+
rewriter.eraseOp(globalOp);
214+
215+
return mlir::LogicalResult::success();
216+
}
217+
}
218+
219+
return op.emitError()
220+
<< "lowering of printf function with Format-String"
221+
"defined outside of printf is not supported yet";
222+
}
223+
224+
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
225+
op, op.getCalleeAttr(), types, adaptor.getOperands());
226+
return mlir::LogicalResult::success();
227+
228+
} else {
229+
// TODO: support lowering of indirect calls via func.call_indirect op
230+
return op.emitError() << "lowering of indirect calls not supported yet";
231+
}
115232
}
116233
};
117234

@@ -561,43 +678,60 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
561678
mlir::ConversionPatternRewriter &rewriter) const override {
562679

563680
auto fnType = op.getFunctionType();
564-
mlir::TypeConverter::SignatureConversion signatureConversion(
565-
fnType.getNumInputs());
566-
567-
for (const auto &argType : enumerate(fnType.getInputs())) {
568-
auto convertedType = getTypeConverter()->convertType(argType.value());
569-
if (!convertedType)
570-
return mlir::failure();
571-
signatureConversion.addInputs(argType.index(), convertedType);
572-
}
573681

574-
SmallVector<mlir::NamedAttribute, 2> passThroughAttrs;
682+
if (fnType.isVarArg()) {
683+
// TODO: once the func dialect supports variadic functions rewrite this
684+
// For now only insert special handling of printf via the llvmir dialect
685+
if (op.getSymName().equals_insensitive("printf")) {
686+
auto context = rewriter.getContext();
687+
// Create a llvmir dialect function declaration for printf, the
688+
// signature is: i32 (!llvm.ptr, ...)
689+
auto llvmI32Ty = mlir::IntegerType::get(context, 32);
690+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context);
691+
auto llvmFnType =
692+
mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
693+
/*isVarArg=*/true);
694+
auto printfFunc = rewriter.create<mlir::LLVM::LLVMFuncOp>(
695+
op.getLoc(), "printf", llvmFnType);
696+
rewriter.replaceOp(op, printfFunc);
697+
} else {
698+
rewriter.eraseOp(op);
699+
return op.emitError() << "lowering of variadic functions (except "
700+
"printf) not supported yet";
701+
}
702+
} else {
703+
mlir::TypeConverter::SignatureConversion signatureConversion(
704+
fnType.getNumInputs());
705+
706+
for (const auto &argType : enumerate(fnType.getInputs())) {
707+
auto convertedType = typeConverter->convertType(argType.value());
708+
if (!convertedType)
709+
return mlir::failure();
710+
signatureConversion.addInputs(argType.index(), convertedType);
711+
}
575712

576-
if (auto symVisibilityAttr = op.getSymVisibilityAttr())
577-
passThroughAttrs.push_back(
578-
rewriter.getNamedAttr("sym_visibility", symVisibilityAttr));
713+
SmallVector<mlir::NamedAttribute, 2> passThroughAttrs;
579714

580-
SmallVector<mlir::Type> resultTypes;
581-
// Only convert return type if the function is not void
582-
if (!mlir::isa<cir::VoidType>(fnType.getReturnType())) {
583-
auto resultType = getTypeConverter()->convertType(fnType.getReturnType());
584-
if (!resultType)
585-
return mlir::failure();
586-
resultTypes.push_back(resultType);
587-
}
715+
if (auto symVisibilityAttr = op.getSymVisibilityAttr())
716+
passThroughAttrs.push_back(
717+
rewriter.getNamedAttr("sym_visibility", symVisibilityAttr));
588718

589-
auto fn = rewriter.create<mlir::func::FuncOp>(
590-
op.getLoc(), op.getName(),
591-
rewriter.getFunctionType(signatureConversion.getConvertedTypes(),
592-
resultTypes),
593-
passThroughAttrs);
719+
mlir::Type resultType =
720+
getTypeConverter()->convertType(fnType.getReturnType());
721+
auto fn = rewriter.create<mlir::func::FuncOp>(
722+
op.getLoc(), op.getName(),
723+
rewriter.getFunctionType(signatureConversion.getConvertedTypes(),
724+
resultType ? mlir::TypeRange(resultType)
725+
: mlir::TypeRange()),
726+
passThroughAttrs);
594727

595-
if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
596-
&signatureConversion)))
597-
return mlir::failure();
598-
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
728+
if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter,
729+
&signatureConversion)))
730+
return mlir::failure();
731+
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
599732

600-
rewriter.eraseOp(op);
733+
rewriter.eraseOp(op);
734+
}
601735
return mlir::LogicalResult::success();
602736
}
603737
};

clang/test/CIR/Lowering/ThroughMLIR/call.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,41 @@ int test(void) {
1212
// CHECK: %[[ARG:.+]] = arith.constant 2 : i32
1313
// CHECK-NEXT: call @foo(%[[ARG]]) : (i32) -> ()
1414
// CHECK: }
15+
16+
extern int printf(const char *str, ...);
17+
18+
// CHECK-LABEL: llvm.func @printf(!llvm.ptr, ...) -> i32
19+
// CHECK: llvm.mlir.global internal constant @[[FRMT_STR:.*]](dense<[37, 100, 44, 32, 37, 102, 44, 32, 37, 100, 44, 32, 37, 108, 108, 100, 44, 32, 37, 100, 44, 32, 37, 102, 10, 0]> : tensor<26xi8>) {addr_space = 0 : i32} : !llvm.array<26 x i8>
20+
21+
void testfunc(short s, float X, char C, long long LL, int I, double D) {
22+
printf("%d, %f, %d, %lld, %d, %f\n", s, X, C, LL, I, D);
23+
}
24+
25+
// CHECK: func.func @testfunc(%[[ARG0:.*]]: i16 {{.*}}, %[[ARG1:.*]]: f32 {{.*}}, %[[ARG2:.*]]: i8 {{.*}}, %[[ARG3:.*]]: i64 {{.*}}, %[[ARG4:.*]]: i32 {{.*}}, %[[ARG5:.*]]: f64 {{.*}}) {
26+
// CHECK: %[[ALLOCA_S:.*]] = memref.alloca() {alignment = 2 : i64} : memref<i16>
27+
// CHECK: %[[ALLOCA_X:.*]] = memref.alloca() {alignment = 4 : i64} : memref<f32>
28+
// CHECK: %[[ALLOCA_C:.*]] = memref.alloca() {alignment = 1 : i64} : memref<i8>
29+
// CHECK: %[[ALLOCA_LL:.*]] = memref.alloca() {alignment = 8 : i64} : memref<i64>
30+
// CHECK: %[[ALLOCA_I:.*]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
31+
// CHECK: %[[ALLOCA_D:.*]] = memref.alloca() {alignment = 8 : i64} : memref<f64>
32+
// CHECK: memref.store %[[ARG0]], %[[ALLOCA_S]][] : memref<i16>
33+
// CHECK: memref.store %[[ARG1]], %[[ALLOCA_X]][] : memref<f32>
34+
// CHECK: memref.store %[[ARG2]], %[[ALLOCA_C]][] : memref<i8>
35+
// CHECK: memref.store %[[ARG3]], %[[ALLOCA_LL]][] : memref<i64>
36+
// CHECK: memref.store %[[ARG4]], %[[ALLOCA_I]][] : memref<i32>
37+
// CHECK: memref.store %[[ARG5]], %[[ALLOCA_D]][] : memref<f64>
38+
// CHECK: %[[FRMT_STR_ADDR:.*]] = llvm.mlir.addressof @[[FRMT_STR]] : !llvm.ptr
39+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i8
40+
// CHECK: %[[FRMT_STR_DATA:.*]] = llvm.getelementptr %[[FRMT_STR_ADDR]][%[[C0]], %[[C0]]] : (!llvm.ptr, i8, i8) -> !llvm.ptr, !llvm.array<26 x i8>
41+
// CHECK: %[[S:.*]] = memref.load %[[ALLOCA_S]][] : memref<i16>
42+
// CHECK: %[[S_EXT:.*]] = arith.extsi %3 : i16 to i32
43+
// CHECK: %[[X:.*]] = memref.load %[[ALLOCA_X]][] : memref<f32>
44+
// CHECK: %[[X_EXT:.*]] = arith.extf %5 : f32 to f64
45+
// CHECK: %[[C:.*]] = memref.load %[[ALLOCA_C]][] : memref<i8>
46+
// CHECK: %[[C_EXT:.*]] = arith.extsi %7 : i8 to i32
47+
// CHECK: %[[LL:.*]] = memref.load %[[ALLOCA_LL]][] : memref<i64>
48+
// CHECK: %[[I:.*]] = memref.load %[[ALLOCA_I]][] : memref<i32>
49+
// CHECK: %[[D:.*]] = memref.load %[[ALLOCA_D]][] : memref<f64>
50+
// CHECK: {{.*}} = llvm.call @printf(%[[FRMT_STR_DATA]], %[[S_EXT]], %[[X_EXT]], %[[C_EXT]], %[[LL]], %[[I]], %[[D]]) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32, f64, i32, i64, i32, f64) -> i32
51+
// CHECK: return
52+
// CHECK: }

0 commit comments

Comments
 (0)