|
21 | 21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
22 | 22 | #include "mlir/IR/BuiltinDialect.h" |
23 | 23 | #include "mlir/IR/BuiltinOps.h" |
| 24 | +#include "mlir/IR/Types.h" |
24 | 25 | #include "mlir/Pass/Pass.h" |
25 | 26 | #include "mlir/Pass/PassManager.h" |
26 | 27 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
@@ -1193,6 +1194,86 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( |
1193 | 1194 | return mlir::LogicalResult::success(); |
1194 | 1195 | } |
1195 | 1196 |
|
| 1197 | +/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind. |
| 1198 | +static mlir::LLVM::ICmpPredicate |
| 1199 | +convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) { |
| 1200 | + using CIR = cir::CmpOpKind; |
| 1201 | + using LLVMICmp = mlir::LLVM::ICmpPredicate; |
| 1202 | + switch (kind) { |
| 1203 | + case CIR::eq: |
| 1204 | + return LLVMICmp::eq; |
| 1205 | + case CIR::ne: |
| 1206 | + return LLVMICmp::ne; |
| 1207 | + case CIR::lt: |
| 1208 | + return (isSigned ? LLVMICmp::slt : LLVMICmp::ult); |
| 1209 | + case CIR::le: |
| 1210 | + return (isSigned ? LLVMICmp::sle : LLVMICmp::ule); |
| 1211 | + case CIR::gt: |
| 1212 | + return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt); |
| 1213 | + case CIR::ge: |
| 1214 | + return (isSigned ? LLVMICmp::sge : LLVMICmp::uge); |
| 1215 | + } |
| 1216 | + llvm_unreachable("Unknown CmpOpKind"); |
| 1217 | +} |
| 1218 | + |
| 1219 | +/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison |
| 1220 | +/// kind. |
| 1221 | +static mlir::LLVM::FCmpPredicate |
| 1222 | +convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) { |
| 1223 | + using CIR = cir::CmpOpKind; |
| 1224 | + using LLVMFCmp = mlir::LLVM::FCmpPredicate; |
| 1225 | + switch (kind) { |
| 1226 | + case CIR::eq: |
| 1227 | + return LLVMFCmp::oeq; |
| 1228 | + case CIR::ne: |
| 1229 | + return LLVMFCmp::une; |
| 1230 | + case CIR::lt: |
| 1231 | + return LLVMFCmp::olt; |
| 1232 | + case CIR::le: |
| 1233 | + return LLVMFCmp::ole; |
| 1234 | + case CIR::gt: |
| 1235 | + return LLVMFCmp::ogt; |
| 1236 | + case CIR::ge: |
| 1237 | + return LLVMFCmp::oge; |
| 1238 | + } |
| 1239 | + llvm_unreachable("Unknown CmpOpKind"); |
| 1240 | +} |
| 1241 | + |
| 1242 | +mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( |
| 1243 | + cir::CmpOp cmpOp, OpAdaptor adaptor, |
| 1244 | + mlir::ConversionPatternRewriter &rewriter) const { |
| 1245 | + mlir::Type type = cmpOp.getLhs().getType(); |
| 1246 | + |
| 1247 | + assert(!cir::MissingFeatures::dataMemberType()); |
| 1248 | + assert(!cir::MissingFeatures::methodType()); |
| 1249 | + |
| 1250 | + // Lower to LLVM comparison op. |
| 1251 | + if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) { |
| 1252 | + bool isSigned = mlir::isa<cir::IntType>(type) |
| 1253 | + ? mlir::cast<cir::IntType>(type).isSigned() |
| 1254 | + : mlir::cast<mlir::IntegerType>(type).isSigned(); |
| 1255 | + mlir::LLVM::ICmpPredicate kind = |
| 1256 | + convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); |
| 1257 | + rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| 1258 | + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| 1259 | + } else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) { |
| 1260 | + mlir::LLVM::ICmpPredicate kind = |
| 1261 | + convertCmpKindToICmpPredicate(cmpOp.getKind(), |
| 1262 | + /* isSigned=*/false); |
| 1263 | + rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| 1264 | + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| 1265 | + } else if (mlir::isa<cir::CIRFPTypeInterface>(type)) { |
| 1266 | + mlir::LLVM::FCmpPredicate kind = |
| 1267 | + convertCmpKindToFCmpPredicate(cmpOp.getKind()); |
| 1268 | + rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>( |
| 1269 | + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| 1270 | + } else { |
| 1271 | + return cmpOp.emitError() << "unsupported type for CmpOp: " << type; |
| 1272 | + } |
| 1273 | + |
| 1274 | + return mlir::success(); |
| 1275 | +} |
| 1276 | + |
1196 | 1277 | static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, |
1197 | 1278 | mlir::DataLayout &dataLayout) { |
1198 | 1279 | converter.addConversion([&](cir::PointerType type) -> mlir::Type { |
@@ -1334,6 +1415,7 @@ void ConvertCIRToLLVMPass::runOnOperation() { |
1334 | 1415 | CIRToLLVMBinOpLowering, |
1335 | 1416 | CIRToLLVMBrCondOpLowering, |
1336 | 1417 | CIRToLLVMBrOpLowering, |
| 1418 | + CIRToLLVMCmpOpLowering, |
1337 | 1419 | CIRToLLVMConstantOpLowering, |
1338 | 1420 | CIRToLLVMFuncOpLowering, |
1339 | 1421 | CIRToLLVMTrapOpLowering, |
|
0 commit comments