@@ -2896,6 +2896,18 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
2896
2896
}
2897
2897
}
2898
2898
case clang::BinaryOperator::Opcode::BO_Add: {
2899
+ auto emitSubindex = [&](auto mr, auto ptradd) {
2900
+ auto mt = mr.getType ().template dyn_cast <mlir::MemRefType>();
2901
+ auto shape = std::vector<int64_t >(mt.getShape ());
2902
+ shape[0 ] = -1 ;
2903
+ auto mt0 = mlir::MemRefType::get (shape, mt.getElementType (),
2904
+ MemRefLayoutAttrInterface (),
2905
+ mt.getMemorySpace ());
2906
+ ptradd = castToIndex (loc, ptradd);
2907
+ return ValueCategory (
2908
+ builder.create <polygeist::SubIndexOp>(loc, mt0, mr, ptradd),
2909
+ /* isReference*/ false );
2910
+ };
2899
2911
if (auto cty = dyn_cast<clang::ComplexType>(BO->getType ())) {
2900
2912
mlir::Value real =
2901
2913
builder.create <AddFOp>(loc, getComplexPart (loc, lhs.val , 0 ),
@@ -2911,17 +2923,10 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
2911
2923
if (lhs_v.getType ().isa <mlir::FloatType>()) {
2912
2924
return ValueCategory (builder.create <AddFOp>(loc, lhs_v, rhs_v),
2913
2925
/* isReference*/ false );
2914
- } else if (auto mt = lhs_v.getType ().dyn_cast <mlir::MemRefType>()) {
2915
- auto shape = std::vector<int64_t >(mt.getShape ());
2916
- shape[0 ] = -1 ;
2917
- auto mt0 = mlir::MemRefType::get (shape, mt.getElementType (),
2918
- MemRefLayoutAttrInterface (),
2919
- mt.getMemorySpace ());
2920
- auto ptradd = rhs_v;
2921
- ptradd = castToIndex (loc, ptradd);
2922
- return ValueCategory (
2923
- builder.create <polygeist::SubIndexOp>(loc, mt0, lhs_v, ptradd),
2924
- /* isReference*/ false );
2926
+ } else if (lhs_v.getType ().isa <mlir::MemRefType>()) {
2927
+ return emitSubindex (lhs_v, rhs_v);
2928
+ } else if (rhs_v.getType ().isa <mlir::MemRefType>()) {
2929
+ return emitSubindex (rhs_v, lhs_v);
2925
2930
} else if (auto pt =
2926
2931
lhs_v.getType ().dyn_cast <mlir::LLVM::LLVMPointerType>()) {
2927
2932
return ValueCategory (
0 commit comments