Skip to content

Commit 0a66217

Browse files
committed
[CIR][Lowering][MLIR] Lower cir.cast(bitcast) between !cir.ptr
1 parent 6930b3d commit 0a66217

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,11 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
210210
while (mlir::Operation *addrOp = addr.getDefiningOp()) {
211211
if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
212212
break;
213-
indices.push_back(addrOp->getOperand(1));
214213
addr = addrOp->getOperand(0);
215214
eraseList.push_back(addrOp);
215+
// If there is another operand, assume it is the lowered index
216+
if (addrOp->getNumOperands() == 2)
217+
indices.push_back(addrOp->getOperand(1));
216218
}
217219
base = addr;
218220
if (indices.size() == 0)
@@ -1263,6 +1265,23 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
12631265
op, newDstType, src, 0, std::nullopt, std::nullopt);
12641266
return mlir::success();
12651267
}
1268+
case CIR::bitcast: {
1269+
// clang-format off
1270+
// %7 = cir.cast(bitcast, %6 : !cir.ptr<!s32i>), !cir.ptr<!cir.array<!s32i x 8192>>
1271+
// Is lowered as
1272+
// memref<i32> → memref.reinterpret_cast → memref<8192xi32>
1273+
// clang-format on
1274+
auto newDstType = convertTy(dstType);
1275+
if (!(mlir::isa<mlir::MemRefType>(adaptor.getSrc().getType()) &&
1276+
mlir::isa<mlir::MemRefType>(newDstType)))
1277+
return op.emitError() << "NYI bitcast from " << op.getSrc().getType()
1278+
<< " to " << dstType;
1279+
auto dstMR = mlir::cast<mlir::MemRefType>(newDstType);
1280+
auto [strides, offset] = dstMR.getStridesAndOffset();
1281+
rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
1282+
op, dstMR, src, offset, dstMR.getShape(), strides);
1283+
return mlir::success();
1284+
}
12661285
case CIR::int_to_bool: {
12671286
auto zero = rewriter.create<cir::ConstantOp>(
12681287
src.getLoc(), op.getSrc().getType(),

0 commit comments

Comments
 (0)