Skip to content

Commit ab35df7

Browse files
committed
[CIR][Lowering][MLIR] Lower cir.cast(bitcast) between !cir.ptr
1 parent da4dbe2 commit ab35df7

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,11 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
211211
while (mlir::Operation *addrOp = addr.getDefiningOp()) {
212212
if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
213213
break;
214-
indices.push_back(addrOp->getOperand(1));
215214
addr = addrOp->getOperand(0);
216215
eraseList.push_back(addrOp);
216+
// If there is another operand, assume it is the lowered index
217+
if (addrOp->getNumOperands() == 2)
218+
indices.push_back(addrOp->getOperand(1));
217219
}
218220
base = addr;
219221
if (indices.size() == 0)
@@ -1262,6 +1264,23 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
12621264
op, newDstType, src, 0, std::nullopt, std::nullopt);
12631265
return mlir::success();
12641266
}
1267+
case CIR::bitcast: {
1268+
// clang-format off
1269+
// %7 = cir.cast(bitcast, %6 : !cir.ptr<!s32i>), !cir.ptr<!cir.array<!s32i x 8192>>
1270+
// Is lowered as
1271+
// memref<i32> → memref.reinterpret_cast → memref<8192xi32>
1272+
// clang-format on
1273+
auto newDstType = convertTy(dstType);
1274+
if (!(mlir::isa<mlir::MemRefType>(adaptor.getSrc().getType()) &&
1275+
mlir::isa<mlir::MemRefType>(newDstType)))
1276+
return op.emitError() << "NYI bitcast from " << op.getSrc().getType()
1277+
<< " to " << dstType;
1278+
auto dstMR = mlir::cast<mlir::MemRefType>(newDstType);
1279+
auto [strides, offset] = dstMR.getStridesAndOffset();
1280+
rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
1281+
op, dstMR, src, offset, dstMR.getShape(), strides);
1282+
return mlir::success();
1283+
}
12651284
case CIR::int_to_bool: {
12661285
auto zero = rewriter.create<cir::ConstantOp>(
12671286
src.getLoc(), op.getSrc().getType(),

0 commit comments

Comments
 (0)