@@ -210,9 +210,11 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
210
210
while (mlir::Operation *addrOp = addr.getDefiningOp ()) {
211
211
if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
212
212
break ;
213
- indices.push_back (addrOp->getOperand (1 ));
214
213
addr = addrOp->getOperand (0 );
215
214
eraseList.push_back (addrOp);
215
+ // If there is another operand, assume it is the lowered index
216
+ if (addrOp->getNumOperands () == 2 )
217
+ indices.push_back (addrOp->getOperand (1 ));
216
218
}
217
219
base = addr;
218
220
if (indices.size () == 0 )
@@ -1263,6 +1265,23 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
1263
1265
op, newDstType, src, 0 , std::nullopt, std::nullopt);
1264
1266
return mlir::success ();
1265
1267
}
1268
+ case CIR::bitcast: {
1269
+ // clang-format off
1270
+ // %7 = cir.cast(bitcast, %6 : !cir.ptr<!s32i>), !cir.ptr<!cir.array<!s32i x 8192>>
1271
+ // Is lowered as
1272
+ // memref<i32> → memref.reinterpret_cast → memref<8192xi32>
1273
+ // clang-format on
1274
+ auto newDstType = convertTy (dstType);
1275
+ if (!(mlir::isa<mlir::MemRefType>(adaptor.getSrc ().getType ()) &&
1276
+ mlir::isa<mlir::MemRefType>(newDstType)))
1277
+ return op.emitError () << " NYI bitcast from " << op.getSrc ().getType ()
1278
+ << " to " << dstType;
1279
+ auto dstMR = mlir::cast<mlir::MemRefType>(newDstType);
1280
+ auto [strides, offset] = dstMR.getStridesAndOffset ();
1281
+ rewriter.replaceOpWithNewOp <mlir::memref::ReinterpretCastOp>(
1282
+ op, dstMR, src, offset, dstMR.getShape (), strides);
1283
+ return mlir::success ();
1284
+ }
1266
1285
case CIR::int_to_bool: {
1267
1286
auto zero = rewriter.create <cir::ConstantOp>(
1268
1287
src.getLoc (), op.getSrc ().getType (),
0 commit comments