Skip to content

Commit e3d6193

Browse files
xlaukolanza
authored andcommitted
[CIR] Clean up ptr_stride and its nowrap flag (#1933)
1 parent 6a04fd6 commit e3d6193

File tree

4 files changed

+38
-31
lines changed

4 files changed

+38
-31
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,9 +683,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
683683

684684
cir::PtrStrideOp
685685
createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride,
686-
std::optional<CIR_GEPNoWrapFlags> flags = std::nullopt) {
686+
std::optional<GEPNoWrapFlags> flags = std::nullopt) {
687687
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride,
688-
flags.value_or(CIR_GEPNoWrapFlags::none));
688+
flags.value_or(GEPNoWrapFlags::none));
689689
}
690690

691691
cir::CallOp createCallOp(mlir::Location loc,

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -406,57 +406,66 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
406406
//===----------------------------------------------------------------------===//
407407
// PtrStrideOp
408408
//===----------------------------------------------------------------------===//
409+
410+
// These mirror the GEPNoWrapFlags in LLVM IR Dialect.
409411
def CIR_GEPNone : I32BitEnumCaseNone<"none">;
410412
def CIR_GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
411413
def CIR_GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
412414
def CIR_GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
413-
def CIR_GEPInbounds
414-
: BitEnumCaseGroup<"inbounds", [CIR_GEPInboundsFlag, CIR_GEPNusw]>;
415-
416-
def CIR_GEPNoWrapFlags
417-
: CIR_I32BitEnum<"CIR_GEPNoWrapFlags", "::cir::CIR_GEPNoWrapFlags",
418-
[CIR_GEPNone, CIR_GEPInboundsFlag, CIR_GEPNusw, CIR_GEPNuw,
419-
CIR_GEPInbounds]> {
420-
let cppNamespace = "::cir";
415+
def CIR_GEPInbounds : BitEnumCaseGroup<"inbounds", [
416+
CIR_GEPInboundsFlag, CIR_GEPNusw]>;
417+
418+
def CIR_GEPNoWrapFlags : CIR_I32BitEnum<"GEPNoWrapFlags", "no-wrap flags", [
419+
CIR_GEPNone,
420+
CIR_GEPInboundsFlag,
421+
CIR_GEPNusw,
422+
CIR_GEPNuw,
423+
CIR_GEPInbounds
424+
]> {
421425
let printBitEnumPrimaryGroups = 1;
422426
}
423427

424428
def CIR_GEPNoWrapFlagsProp : EnumProp<CIR_GEPNoWrapFlags> {
425-
let defaultValue = interfaceType#"::none";
429+
let defaultValue = enum.cppType # "::" # "none";
426430
}
427431

428432
def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[
429433
Pure, AllTypesMatch<["base", "result"]>
430434
]> {
431435
let summary = "Pointer access with stride";
432436
let description = [{
433-
Given a base pointer as first operand, provides a new pointer after applying
434-
a stride (second operand).
435-
436-
```mlir
437-
%3 = cir.const 0 : i32
438-
439-
%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>
437+
The `cir.ptr_stride` operation computes a new pointer from a base pointer
438+
and an integer stride, similar to a single-index `getelementptr` in LLVM IR.
439+
It moves the pointer by `stride * sizeof(element_type)` bytes.
440440

441-
%5 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds), !cir.ptr<i32>
441+
Optional no-wrap flags refine pointer arithmetic semantics, that mirror
442+
LLVM's GEP no-wrap semantics.
442443

443-
%6 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds|nuw), !cir.ptr<i32>
444+
Example:
444445

446+
```mlir
447+
%3 = cir.ptr_stride %1, %2 : (!cir.ptr<i32>, i32) ->!cir.ptr<i32>
448+
%4 = cir.ptr_stride inbounds %1, %2 : (!cir.ptr<i32>, i32) -> !cir.ptr<i32>
449+
%5 = cir.ptr_stride inbounds|nuw %1, %2 : (!cir.ptr<i32>, i32) -> !cir.ptr<i32>
445450
```
446451
}];
447452

448-
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride,
449-
CIR_GEPNoWrapFlagsProp:$noWrapFlags);
453+
let arguments = (ins
454+
CIR_PointerType:$base,
455+
CIR_AnyFundamentalIntType:$stride,
456+
CIR_GEPNoWrapFlagsProp:$noWrapFlags
457+
);
450458

451459
let results = (outs CIR_PointerType:$result);
452460

453461
let assemblyFormat = [{
454-
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results) attr-dict
462+
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results)
463+
attr-dict
455464
}];
456465

457466
let extraClassDeclaration = [{
458467
// Get type pointed by the base pointer.
459-
mlir::Type getElementTy() {
468+
mlir::Type getElementType() {
460469
return getBase().getType().getPointee();
461470
}
462471
}];

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,9 +2853,9 @@ mlir::Value CIRGenFunction::emitCheckedInBoundsGEP(
28532853
builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr, IdxList[0]);
28542854
// If the pointer overflow sanitizer isn't enabled, do nothing.
28552855
if (!SanOpts.has(SanitizerKind::PointerOverflow)) {
2856-
cir::CIR_GEPNoWrapFlags nwFlags = cir::CIR_GEPNoWrapFlags::inbounds;
2856+
cir::GEPNoWrapFlags nwFlags = cir::GEPNoWrapFlags::inbounds;
28572857
if (!SignedIndices && !IsSubtraction)
2858-
nwFlags = nwFlags | cir::CIR_GEPNoWrapFlags::nuw;
2858+
nwFlags = nwFlags | cir::GEPNoWrapFlags::nuw;
28592859
return builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr,
28602860
IdxList[0], nwFlags);
28612861
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,15 @@ void walkRegionSkipping(mlir::Region &region,
9797

9898
/// Convert from a CIR PtrStrideOp kind to an LLVM IR equivalent of GEP.
9999
mlir::LLVM::GEPNoWrapFlags
100-
convertPtrStrideKindToGEPFlags(cir::CIR_GEPNoWrapFlags flags) {
101-
using CIRFlags = cir::CIR_GEPNoWrapFlags;
100+
convertPtrStrideKindToGEPFlags(cir::GEPNoWrapFlags flags) {
101+
using CIRFlags = cir::GEPNoWrapFlags;
102102
using LLVMFlags = mlir::LLVM::GEPNoWrapFlags;
103103

104104
LLVMFlags x = LLVMFlags::none;
105105
if ((flags & CIRFlags::inboundsFlag) == CIRFlags::inboundsFlag)
106106
x = x | LLVMFlags::inboundsFlag;
107107
if ((flags & CIRFlags::nusw) == CIRFlags::nusw)
108108
x = x | LLVMFlags::nusw;
109-
if ((flags & CIRFlags::inbounds) == CIRFlags::inbounds)
110-
x = x | LLVMFlags::inbounds;
111109
if ((flags & CIRFlags::nuw) == CIRFlags::nuw)
112110
x = x | LLVMFlags::nuw;
113111
return x;
@@ -1086,7 +1084,7 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
10861084
auto *tc = getTypeConverter();
10871085
const auto resultTy = tc->convertType(ptrStrideOp.getType());
10881086
auto elementTy =
1089-
convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy());
1087+
convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementType());
10901088
auto *ctx = elementTy.getContext();
10911089

10921090
// void and function types doesn't really have a layout to use in GEPs,

0 commit comments

Comments
 (0)