@@ -211,9 +211,11 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
211
211
while (mlir::Operation *addrOp = addr.getDefiningOp ()) {
212
212
if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
213
213
break ;
214
- indices.push_back (addrOp->getOperand (1 ));
215
214
addr = addrOp->getOperand (0 );
216
215
eraseList.push_back (addrOp);
216
+ // If there is another operand, assume it is the lowered index
217
+ if (addrOp->getNumOperands () == 2 )
218
+ indices.push_back (addrOp->getOperand (1 ));
217
219
}
218
220
base = addr;
219
221
if (indices.size () == 0 )
@@ -1262,6 +1264,23 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
1262
1264
op, newDstType, src, 0 , std::nullopt, std::nullopt);
1263
1265
return mlir::success ();
1264
1266
}
1267
+ case CIR::bitcast: {
1268
+ // clang-format off
1269
+ // %7 = cir.cast(bitcast, %6 : !cir.ptr<!s32i>), !cir.ptr<!cir.array<!s32i x 8192>>
1270
+ // Is lowered as
1271
+ // memref<i32> → memref.reinterpret_cast → memref<8192xi32>
1272
+ // clang-format on
1273
+ auto newDstType = convertTy (dstType);
1274
+ if (!(mlir::isa<mlir::MemRefType>(adaptor.getSrc ().getType ()) &&
1275
+ mlir::isa<mlir::MemRefType>(newDstType)))
1276
+ return op.emitError () << " NYI bitcast from " << op.getSrc ().getType ()
1277
+ << " to " << dstType;
1278
+ auto dstMR = mlir::cast<mlir::MemRefType>(newDstType);
1279
+ auto [strides, offset] = dstMR.getStridesAndOffset ();
1280
+ rewriter.replaceOpWithNewOp <mlir::memref::ReinterpretCastOp>(
1281
+ op, dstMR, src, offset, dstMR.getShape (), strides);
1282
+ return mlir::success ();
1283
+ }
1265
1284
case CIR::int_to_bool: {
1266
1285
auto zero = rewriter.create <cir::ConstantOp>(
1267
1286
src.getLoc (), op.getSrc ().getType (),
0 commit comments