Skip to content

Commit c1df45c

Browse files
authored
[CIR] Clean up ptr_stride and its nowrap flag (#1933)
1 parent 6b661d5 commit c1df45c

File tree

4 files changed

+38
-31
lines changed

4 files changed

+38
-31
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,9 +683,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
683683

684684
cir::PtrStrideOp
685685
createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride,
686-
std::optional<CIR_GEPNoWrapFlags> flags = std::nullopt) {
686+
std::optional<GEPNoWrapFlags> flags = std::nullopt) {
687687
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride,
688-
flags.value_or(CIR_GEPNoWrapFlags::none));
688+
flags.value_or(GEPNoWrapFlags::none));
689689
}
690690

691691
cir::CallOp createCallOp(mlir::Location loc,

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -401,57 +401,66 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
401401
//===----------------------------------------------------------------------===//
402402
// PtrStrideOp
403403
//===----------------------------------------------------------------------===//
404+
405+
// These mirror the GEPNoWrapFlags in LLVM IR Dialect.
404406
def CIR_GEPNone : I32BitEnumCaseNone<"none">;
405407
def CIR_GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
406408
def CIR_GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
407409
def CIR_GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
408-
def CIR_GEPInbounds
409-
: BitEnumCaseGroup<"inbounds", [CIR_GEPInboundsFlag, CIR_GEPNusw]>;
410-
411-
def CIR_GEPNoWrapFlags
412-
: CIR_I32BitEnum<"CIR_GEPNoWrapFlags", "::cir::CIR_GEPNoWrapFlags",
413-
[CIR_GEPNone, CIR_GEPInboundsFlag, CIR_GEPNusw, CIR_GEPNuw,
414-
CIR_GEPInbounds]> {
415-
let cppNamespace = "::cir";
410+
def CIR_GEPInbounds : BitEnumCaseGroup<"inbounds", [
411+
CIR_GEPInboundsFlag, CIR_GEPNusw]>;
412+
413+
def CIR_GEPNoWrapFlags : CIR_I32BitEnum<"GEPNoWrapFlags", "no-wrap flags", [
414+
CIR_GEPNone,
415+
CIR_GEPInboundsFlag,
416+
CIR_GEPNusw,
417+
CIR_GEPNuw,
418+
CIR_GEPInbounds
419+
]> {
416420
let printBitEnumPrimaryGroups = 1;
417421
}
418422

419423
def CIR_GEPNoWrapFlagsProp : EnumProp<CIR_GEPNoWrapFlags> {
420-
let defaultValue = interfaceType#"::none";
424+
let defaultValue = enum.cppType # "::" # "none";
421425
}
422426

423427
def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[
424428
Pure, AllTypesMatch<["base", "result"]>
425429
]> {
426430
let summary = "Pointer access with stride";
427431
let description = [{
428-
Given a base pointer as first operand, provides a new pointer after applying
429-
a stride (second operand).
430-
431-
```mlir
432-
%3 = cir.const 0 : i32
433-
434-
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>
432+
The `cir.ptr_stride` operation computes a new pointer from a base pointer
433+
and an integer stride, similar to a single-index `getelementptr` in LLVM IR.
434+
It moves the pointer by `stride * sizeof(element_type)` bytes.
435435

436-
%5 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds), !cir.ptr<i32>
436+
Optional no-wrap flags refine pointer arithmetic semantics, that mirror
437+
LLVM's GEP no-wrap semantics.
437438

438-
%6 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds|nuw), !cir.ptr<i32>
439+
Example:
439440

441+
```mlir
442+
%3 = cir.ptr_stride %1, %2 : (!cir.ptr<i32>, i32) ->!cir.ptr<i32>
443+
%4 = cir.ptr_stride inbounds %1, %2 : (!cir.ptr<i32>, i32) -> !cir.ptr<i32>
444+
%5 = cir.ptr_stride inbounds|nuw %1, %2 : (!cir.ptr<i32>, i32) -> !cir.ptr<i32>
440445
```
441446
}];
442447

443-
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride,
444-
CIR_GEPNoWrapFlagsProp:$noWrapFlags);
448+
let arguments = (ins
449+
CIR_PointerType:$base,
450+
CIR_AnyFundamentalIntType:$stride,
451+
CIR_GEPNoWrapFlagsProp:$noWrapFlags
452+
);
445453

446454
let results = (outs CIR_PointerType:$result);
447455

448456
let assemblyFormat = [{
449-
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results) attr-dict
457+
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results)
458+
attr-dict
450459
}];
451460

452461
let extraClassDeclaration = [{
453462
// Get type pointed by the base pointer.
454-
mlir::Type getElementTy() {
463+
mlir::Type getElementType() {
455464
return getBase().getType().getPointee();
456465
}
457466
}];

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,9 +2850,9 @@ mlir::Value CIRGenFunction::emitCheckedInBoundsGEP(
28502850
builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr, IdxList[0]);
28512851
// If the pointer overflow sanitizer isn't enabled, do nothing.
28522852
if (!SanOpts.has(SanitizerKind::PointerOverflow)) {
2853-
cir::CIR_GEPNoWrapFlags nwFlags = cir::CIR_GEPNoWrapFlags::inbounds;
2853+
cir::GEPNoWrapFlags nwFlags = cir::GEPNoWrapFlags::inbounds;
28542854
if (!SignedIndices && !IsSubtraction)
2855-
nwFlags = nwFlags | cir::CIR_GEPNoWrapFlags::nuw;
2855+
nwFlags = nwFlags | cir::GEPNoWrapFlags::nuw;
28562856
return builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr,
28572857
IdxList[0], nwFlags);
28582858
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,15 @@ void walkRegionSkipping(mlir::Region &region,
9494

9595
/// Convert from a CIR PtrStrideOp kind to an LLVM IR equivalent of GEP.
9696
mlir::LLVM::GEPNoWrapFlags
97-
convertPtrStrideKindToGEPFlags(cir::CIR_GEPNoWrapFlags flags) {
98-
using CIRFlags = cir::CIR_GEPNoWrapFlags;
97+
convertPtrStrideKindToGEPFlags(cir::GEPNoWrapFlags flags) {
98+
using CIRFlags = cir::GEPNoWrapFlags;
9999
using LLVMFlags = mlir::LLVM::GEPNoWrapFlags;
100100

101101
LLVMFlags x = LLVMFlags::none;
102102
if ((flags & CIRFlags::inboundsFlag) == CIRFlags::inboundsFlag)
103103
x = x | LLVMFlags::inboundsFlag;
104104
if ((flags & CIRFlags::nusw) == CIRFlags::nusw)
105105
x = x | LLVMFlags::nusw;
106-
if ((flags & CIRFlags::inbounds) == CIRFlags::inbounds)
107-
x = x | LLVMFlags::inbounds;
108106
if ((flags & CIRFlags::nuw) == CIRFlags::nuw)
109107
x = x | LLVMFlags::nuw;
110108
return x;
@@ -1021,7 +1019,7 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
10211019
auto *tc = getTypeConverter();
10221020
const auto resultTy = tc->convertType(ptrStrideOp.getType());
10231021
auto elementTy =
1024-
convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy());
1022+
convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementType());
10251023
auto *ctx = elementTy.getContext();
10261024

10271025
// void and function types doesn't really have a layout to use in GEPs,

0 commit comments

Comments
 (0)