|
23 | 23 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
24 | 24 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
25 | 25 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 26 | +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
26 | 27 | #include "mlir/Dialect/Math/IR/Math.h" |
27 | 28 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
28 | 29 | #include "mlir/Dialect/SCF/IR/SCF.h" |
29 | 30 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 31 | +#include "mlir/IR/BuiltinAttributes.h" |
30 | 32 | #include "mlir/IR/BuiltinDialect.h" |
31 | 33 | #include "mlir/IR/BuiltinOps.h" |
32 | 34 | #include "mlir/IR/BuiltinTypes.h" |
@@ -109,9 +111,124 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> { |
109 | 111 | if (mlir::failed( |
110 | 112 | getTypeConverter()->convertTypes(op.getResultTypes(), types))) |
111 | 113 | return mlir::failure(); |
112 | | - rewriter.replaceOpWithNewOp<mlir::func::CallOp>( |
113 | | - op, op.getCalleeAttr(), types, adaptor.getOperands()); |
114 | | - return mlir::LogicalResult::success(); |
| 114 | + |
| 115 | + if (!op.isIndirect()) { |
| 116 | + // Currently variadic functions are not supported by the builtin func |
| 117 | + // dialect. For now only basic call to printf are supported by using the |
| 118 | + // llvmir dialect. |
| 119 | + // TODO: remove this and add support for variadic function calls once |
| 120 | + // TODO: supported by the func dialect |
| 121 | + if (op.getCallee()->equals_insensitive("printf")) { |
| 122 | + SmallVector<mlir::Type> operandTypes = |
| 123 | + llvm::to_vector(adaptor.getOperands().getTypes()); |
| 124 | + |
| 125 | + // Drop the initial memref operand type (we replace the memref format |
| 126 | + // string with equivalent llvm.mlir ops) |
| 127 | + operandTypes.erase(operandTypes.begin()); |
| 128 | + |
| 129 | + // Check that the printf attributes can be used in llvmir dialect (i.e |
| 130 | + // they have integer/float type) |
| 131 | + if (!llvm::all_of(operandTypes, [](mlir::Type ty) { |
| 132 | + return mlir::LLVM::isCompatibleType(ty); |
| 133 | + })) { |
| 134 | + return op.emitError() |
| 135 | + << "lowering of printf attributes having a type that is " |
| 136 | + "converted to memref in cir-to-mlir lowering (e.g. " |
| 137 | + "pointers) not supported yet"; |
| 138 | + } |
| 139 | + |
| 140 | + // Currently only versions of printf are supported where the format |
| 141 | + // string is defined inside the printf ==> the lowering of the cir ops |
| 142 | + // will match: |
| 143 | + // %global = memref.get_global %frm_str |
| 144 | + // %* = memref.reinterpret_cast (%global, 0) |
| 145 | + if (auto reinterpret_castOP = |
| 146 | + mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>( |
| 147 | + adaptor.getOperands()[0].getDefiningOp())) { |
| 148 | + if (auto getGlobalOp = |
| 149 | + mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>( |
| 150 | + reinterpret_castOP->getOperand(0).getDefiningOp())) { |
| 151 | + mlir::ModuleOp parentModule = op->getParentOfType<mlir::ModuleOp>(); |
| 152 | + |
| 153 | + auto context = rewriter.getContext(); |
| 154 | + |
| 155 | + // Find the memref.global op defining the frm_str |
| 156 | + auto globalOp = parentModule.lookupSymbol<mlir::memref::GlobalOp>( |
| 157 | + getGlobalOp.getNameAttr()); |
| 158 | + |
| 159 | + rewriter.setInsertionPoint(globalOp); |
| 160 | + |
| 161 | + // Insert a equivalent llvm.mlir.global |
| 162 | + auto initialvalueAttr = |
| 163 | + mlir::dyn_cast_or_null<mlir::DenseIntElementsAttr>( |
| 164 | + globalOp.getInitialValueAttr()); |
| 165 | + |
| 166 | + auto type = mlir::LLVM::LLVMArrayType::get( |
| 167 | + mlir::IntegerType::get(context, 8), |
| 168 | + initialvalueAttr.getNumElements()); |
| 169 | + |
| 170 | + auto llvmglobalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
| 171 | + globalOp->getLoc(), type, true, mlir::LLVM::Linkage::Internal, |
| 172 | + "printf_format_" + globalOp.getSymName().str(), |
| 173 | + initialvalueAttr, 0); |
| 174 | + |
| 175 | + rewriter.setInsertionPoint(getGlobalOp); |
| 176 | + |
| 177 | + // Insert llvmir dialect ops to retrive the !llvm.ptr of the global |
| 178 | + auto globalPtrOp = rewriter.create<mlir::LLVM::AddressOfOp>( |
| 179 | + getGlobalOp->getLoc(), llvmglobalOp); |
| 180 | + |
| 181 | + mlir::Value cst0 = rewriter.create<mlir::LLVM::ConstantOp>( |
| 182 | + getGlobalOp->getLoc(), rewriter.getI8Type(), |
| 183 | + rewriter.getIndexAttr(0)); |
| 184 | + auto gepPtrOp = rewriter.create<mlir::LLVM::GEPOp>( |
| 185 | + getGlobalOp->getLoc(), |
| 186 | + mlir::LLVM::LLVMPointerType::get(context), |
| 187 | + llvmglobalOp.getType(), globalPtrOp, |
| 188 | + ArrayRef<mlir::Value>({cst0, cst0})); |
| 189 | + |
| 190 | + mlir::ValueRange operands = adaptor.getOperands(); |
| 191 | + |
| 192 | + // Replace the old memref operand with the !llvm.ptr for the frm_str |
| 193 | + mlir::SmallVector<mlir::Value> newOperands; |
| 194 | + newOperands.push_back(gepPtrOp); |
| 195 | + newOperands.append(operands.begin() + 1, operands.end()); |
| 196 | + |
| 197 | + // Create the llvmir dialect function type for printf |
| 198 | + auto llvmI32Ty = mlir::IntegerType::get(context, 32); |
| 199 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); |
| 200 | + auto llvmFnType = |
| 201 | + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, |
| 202 | + /*isVarArg=*/true); |
| 203 | + |
| 204 | + rewriter.setInsertionPoint(op); |
| 205 | + |
| 206 | + // Insert an llvm.call op with the updated operands to printf |
| 207 | + rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| 208 | + op, llvmFnType, op.getCalleeAttr(), newOperands); |
| 209 | + |
| 210 | + // Cleanup printf frm_str memref ops |
| 211 | + rewriter.eraseOp(reinterpret_castOP); |
| 212 | + rewriter.eraseOp(getGlobalOp); |
| 213 | + rewriter.eraseOp(globalOp); |
| 214 | + |
| 215 | + return mlir::LogicalResult::success(); |
| 216 | + } |
| 217 | + } |
| 218 | + |
| 219 | + return op.emitError() |
| 220 | + << "lowering of printf function with Format-String" |
| 221 | + "defined outside of printf is not supported yet"; |
| 222 | + } |
| 223 | + |
| 224 | + rewriter.replaceOpWithNewOp<mlir::func::CallOp>( |
| 225 | + op, op.getCalleeAttr(), types, adaptor.getOperands()); |
| 226 | + return mlir::LogicalResult::success(); |
| 227 | + |
| 228 | + } else { |
| 229 | + // TODO: support lowering of indirect calls via func.call_indirect op |
| 230 | + return op.emitError() << "lowering of indirect calls not supported yet"; |
| 231 | + } |
115 | 232 | } |
116 | 233 | }; |
117 | 234 |
|
@@ -561,43 +678,60 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> { |
561 | 678 | mlir::ConversionPatternRewriter &rewriter) const override { |
562 | 679 |
|
563 | 680 | auto fnType = op.getFunctionType(); |
564 | | - mlir::TypeConverter::SignatureConversion signatureConversion( |
565 | | - fnType.getNumInputs()); |
566 | | - |
567 | | - for (const auto &argType : enumerate(fnType.getInputs())) { |
568 | | - auto convertedType = getTypeConverter()->convertType(argType.value()); |
569 | | - if (!convertedType) |
570 | | - return mlir::failure(); |
571 | | - signatureConversion.addInputs(argType.index(), convertedType); |
572 | | - } |
573 | 681 |
|
574 | | - SmallVector<mlir::NamedAttribute, 2> passThroughAttrs; |
| 682 | + if (fnType.isVarArg()) { |
| 683 | + // TODO: once the func dialect supports variadic functions rewrite this |
| 684 | + // For now only insert special handling of printf via the llvmir dialect |
| 685 | + if (op.getSymName().equals_insensitive("printf")) { |
| 686 | + auto context = rewriter.getContext(); |
| 687 | + // Create a llvmir dialect function declaration for printf, the |
| 688 | + // signature is: i32 (!llvm.ptr, ...) |
| 689 | + auto llvmI32Ty = mlir::IntegerType::get(context, 32); |
| 690 | + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); |
| 691 | + auto llvmFnType = |
| 692 | + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, |
| 693 | + /*isVarArg=*/true); |
| 694 | + auto printfFunc = rewriter.create<mlir::LLVM::LLVMFuncOp>( |
| 695 | + op.getLoc(), "printf", llvmFnType); |
| 696 | + rewriter.replaceOp(op, printfFunc); |
| 697 | + } else { |
| 698 | + rewriter.eraseOp(op); |
| 699 | + return op.emitError() << "lowering of variadic functions (except " |
| 700 | + "printf) not supported yet"; |
| 701 | + } |
| 702 | + } else { |
| 703 | + mlir::TypeConverter::SignatureConversion signatureConversion( |
| 704 | + fnType.getNumInputs()); |
| 705 | + |
| 706 | + for (const auto &argType : enumerate(fnType.getInputs())) { |
| 707 | + auto convertedType = typeConverter->convertType(argType.value()); |
| 708 | + if (!convertedType) |
| 709 | + return mlir::failure(); |
| 710 | + signatureConversion.addInputs(argType.index(), convertedType); |
| 711 | + } |
575 | 712 |
|
576 | | - if (auto symVisibilityAttr = op.getSymVisibilityAttr()) |
577 | | - passThroughAttrs.push_back( |
578 | | - rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); |
| 713 | + SmallVector<mlir::NamedAttribute, 2> passThroughAttrs; |
579 | 714 |
|
580 | | - SmallVector<mlir::Type> resultTypes; |
581 | | - // Only convert return type if the function is not void |
582 | | - if (!mlir::isa<cir::VoidType>(fnType.getReturnType())) { |
583 | | - auto resultType = getTypeConverter()->convertType(fnType.getReturnType()); |
584 | | - if (!resultType) |
585 | | - return mlir::failure(); |
586 | | - resultTypes.push_back(resultType); |
587 | | - } |
| 715 | + if (auto symVisibilityAttr = op.getSymVisibilityAttr()) |
| 716 | + passThroughAttrs.push_back( |
| 717 | + rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); |
588 | 718 |
|
589 | | - auto fn = rewriter.create<mlir::func::FuncOp>( |
590 | | - op.getLoc(), op.getName(), |
591 | | - rewriter.getFunctionType(signatureConversion.getConvertedTypes(), |
592 | | - resultTypes), |
593 | | - passThroughAttrs); |
| 719 | + mlir::Type resultType = |
| 720 | + getTypeConverter()->convertType(fnType.getReturnType()); |
| 721 | + auto fn = rewriter.create<mlir::func::FuncOp>( |
| 722 | + op.getLoc(), op.getName(), |
| 723 | + rewriter.getFunctionType(signatureConversion.getConvertedTypes(), |
| 724 | + resultType ? mlir::TypeRange(resultType) |
| 725 | + : mlir::TypeRange()), |
| 726 | + passThroughAttrs); |
594 | 727 |
|
595 | | - if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(), |
596 | | - &signatureConversion))) |
597 | | - return mlir::failure(); |
598 | | - rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
| 728 | + if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, |
| 729 | + &signatureConversion))) |
| 730 | + return mlir::failure(); |
| 731 | + rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
599 | 732 |
|
600 | | - rewriter.eraseOp(op); |
| 733 | + rewriter.eraseOp(op); |
| 734 | + } |
601 | 735 | return mlir::LogicalResult::success(); |
602 | 736 | } |
603 | 737 | }; |
|
0 commit comments