Skip to content

Commit 75293db

Browse files
xlaukolanza
authored andcommitted
[CIR] Remove inferred context from pointer type getters (llvm#1600)
1 parent 544093d commit 75293db

File tree

12 files changed

+38
-53
lines changed

12 files changed

+38
-53
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
101101

102102
cir::PointerType getPointerTo(mlir::Type ty,
103103
cir::AddressSpaceAttr cirAS = {}) {
104-
return cir::PointerType::get(getContext(), ty, cirAS);
104+
return cir::PointerType::get(ty, cirAS);
105105
}
106106

107107
cir::PointerType getPointerTo(mlir::Type ty, clang::LangAS langAS) {
@@ -507,7 +507,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
507507
mlir::cast<cir::PointerType>(recordPtr.getType()).getPointee();
508508
assert(mlir::isa<cir::RecordType>(recordBaseTy));
509509
auto fldTy = mlir::cast<cir::RecordType>(recordBaseTy).getMembers()[idx];
510-
auto fldPtrTy = cir::PointerType::get(getContext(), fldTy);
510+
auto fldPtrTy = cir::PointerType::get(fldTy);
511511
return create<cir::GetMemberOp>(loc, fldPtrTy, recordPtr, fldName, idx);
512512
}
513513

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
444444
return typeCache.UInt8PtrTy;
445445
}
446446
cir::PointerType getUInt32PtrTy(unsigned AddrSpace = 0) {
447-
return cir::PointerType::get(getContext(), typeCache.UInt32Ty);
447+
return cir::PointerType::get(typeCache.UInt32Ty);
448448
}
449449

450450
/// Get a CIR anonymous record type.

clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,16 +216,14 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
216216
auto kernel = [&]() {
217217
if (auto globalOp = llvm::dyn_cast_or_null<cir::GlobalOp>(
218218
KernelHandles[fn.getSymName()])) {
219-
auto kernelTy =
220-
cir::PointerType::get(&cgm.getMLIRContext(), globalOp.getSymType());
219+
auto kernelTy = cir::PointerType::get(globalOp.getSymType());
221220
mlir::Value kernel = builder.create<cir::GetGlobalOp>(
222221
loc, kernelTy, globalOp.getSymName());
223222
return kernel;
224223
}
225224
if (auto funcOp = llvm::dyn_cast_or_null<cir::FuncOp>(
226225
KernelHandles[fn.getSymName()])) {
227-
auto kernelTy = cir::PointerType::get(&cgm.getMLIRContext(),
228-
funcOp.getFunctionType());
226+
auto kernelTy = cir::PointerType::get(funcOp.getFunctionType());
229227
mlir::Value kernel =
230228
builder.create<cir::GetGlobalOp>(loc, kernelTy, funcOp.getSymName());
231229
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ static Address emitAddrOfFieldStorage(CIRGenFunction &CGF, Address Base,
7575
auto loc = CGF.getLoc(field->getLocation());
7676

7777
auto fieldType = CGF.convertType(field->getType());
78-
auto fieldPtr =
79-
cir::PointerType::get(CGF.getBuilder().getContext(), fieldType);
78+
auto fieldPtr = cir::PointerType::get(fieldType);
8079
// For most cases fieldName is the same as field->getName() but for lambdas,
8180
// which do not currently carry the name, so it can be passed down from the
8281
// CaptureStmt.
@@ -262,7 +261,7 @@ Address CIRGenFunction::getAddrOfBitFieldStorage(LValue base,
262261
if (index == 0)
263262
return base.getAddress();
264263
auto loc = getLoc(field->getLocation());
265-
auto fieldPtr = cir::PointerType::get(getBuilder().getContext(), fieldType);
264+
auto fieldPtr = cir::PointerType::get(fieldType);
266265
auto sea = getBuilder().createGetMember(loc, fieldPtr, base.getPointer(),
267266
field->getName(), index);
268267
return Address(sea, CharUnits::One());
@@ -981,13 +980,13 @@ static LValue emitFunctionDeclLValue(CIRGenFunction &CGF, const Expr *E,
981980
CharUnits align = CGF.getContext().getDeclAlign(FD);
982981

983982
mlir::Type fnTy = funcOp.getFunctionType();
984-
auto ptrTy = cir::PointerType::get(CGF.getBuilder().getContext(), fnTy);
983+
auto ptrTy = cir::PointerType::get(fnTy);
985984
mlir::Value addr = CGF.getBuilder().create<cir::GetGlobalOp>(
986985
loc, ptrTy, funcOp.getSymName());
987986

988987
if (funcOp.getFunctionType() != CGF.convertType(FD->getType())) {
989988
fnTy = CGF.convertType(FD->getType());
990-
ptrTy = cir::PointerType::get(CGF.getBuilder().getContext(), fnTy);
989+
ptrTy = cir::PointerType::get(fnTy);
991990

992991
addr = CGF.getBuilder().create<cir::CastOp>(addr.getLoc(), ptrTy,
993992
cir::CastKind::bitcast, addr);
@@ -1538,15 +1537,14 @@ RValue CIRGenFunction::emitCall(clang::QualType CalleeType,
15381537
// get non-variadic function type
15391538
CalleeTy = cir::FuncType::get(CalleeTy.getInputs(),
15401539
CalleeTy.getReturnType(), false);
1541-
auto CalleePtrTy = cir::PointerType::get(&getMLIRContext(), CalleeTy);
1540+
auto CalleePtrTy = cir::PointerType::get(CalleeTy);
15421541

15431542
auto *Fn = Callee.getFunctionPointer();
15441543
mlir::Value Addr;
15451544
if (auto funcOp = llvm::dyn_cast<cir::FuncOp>(Fn)) {
15461545
Addr = builder.create<cir::GetGlobalOp>(
15471546
getLoc(E->getSourceRange()),
1548-
cir::PointerType::get(&getMLIRContext(), funcOp.getFunctionType()),
1549-
funcOp.getSymName());
1547+
cir::PointerType::get(funcOp.getFunctionType()), funcOp.getSymName());
15501548
} else {
15511549
Addr = Fn->getResult(0);
15521550
}
@@ -2954,7 +2952,7 @@ mlir::Value CIRGenFunction::emitLoadOfScalar(Address addr, bool isVolatile,
29542952
auto Ptr = addr.getPointer();
29552953
if (mlir::isa<cir::VoidType>(eltTy)) {
29562954
eltTy = cir::IntType::get(&getMLIRContext(), 8, true);
2957-
auto ElemPtrTy = cir::PointerType::get(&getMLIRContext(), eltTy);
2955+
auto ElemPtrTy = cir::PointerType::get(eltTy);
29582956
Ptr = builder.create<cir::CastOp>(loc, ElemPtrTy, cir::CastKind::bitcast,
29592957
Ptr);
29602958
}

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
138138
VoidTy = cir::VoidType::get(&getMLIRContext());
139139

140140
// Initialize CIR pointer types cache.
141-
VoidPtrTy = cir::PointerType::get(&getMLIRContext(), VoidTy);
142-
VoidPtrPtrTy = cir::PointerType::get(&getMLIRContext(), VoidPtrTy);
141+
VoidPtrTy = cir::PointerType::get(VoidTy);
142+
VoidPtrPtrTy = cir::PointerType::get(VoidPtrTy);
143143

144144
FP16Ty = cir::FP16Type::get(&getMLIRContext());
145145
BFloat16Ty = cir::BF16Type::get(&getMLIRContext());
@@ -948,15 +948,14 @@ static GlobalViewAttr createNewGlobalView(CIRGenModule &cgm, GlobalOp newGlob,
948948
llvm::SmallVector<int64_t> newInds;
949949
CIRGenBuilderTy &bld = cgm.getBuilder();
950950
const CIRDataLayout &layout = cgm.getDataLayout();
951-
mlir::MLIRContext *ctxt = bld.getContext();
952951
auto newTy = newGlob.getSymType();
953952

954953
auto offset = bld.computeOffsetFromGlobalViewIndices(layout, oldTy, oldInds);
955954
bld.computeGlobalViewIndicesFromFlatOffset(offset, newTy, layout, newInds);
956955
cir::PointerType newPtrTy;
957956

958957
if (isa<cir::RecordType>(oldTy))
959-
newPtrTy = cir::PointerType::get(ctxt, newTy);
958+
newPtrTy = cir::PointerType::get(newTy);
960959
else if (cir::ArrayType oldArTy = dyn_cast<cir::ArrayType>(oldTy))
961960
newPtrTy = dyn_cast<cir::PointerType>(attr.getType());
962961

@@ -1015,8 +1014,7 @@ void CIRGenModule::replaceGlobal(cir::GlobalOp oldSym, cir::GlobalOp newSym) {
10151014

10161015
if (auto ggo = dyn_cast<cir::GetGlobalOp>(use.getUser())) {
10171016
auto useOpResultValue = ggo.getAddr();
1018-
useOpResultValue.setType(
1019-
cir::PointerType::get(&getMLIRContext(), newTy));
1017+
useOpResultValue.setType(cir::PointerType::get(newTy));
10201018

10211019
mlir::OpBuilder::InsertionGuard guard(builder);
10221020
builder.setInsertionPointAfter(ggo);
@@ -1470,8 +1468,7 @@ void CIRGenModule::emitGlobalVarDefinition(const clang::VarDecl *d,
14701468
// TODO(cir): pointer to array decay. Should this be modeled explicitly in
14711469
// CIR?
14721470
if (arrayTy)
1473-
initType =
1474-
cir::PointerType::get(&getMLIRContext(), arrayTy.getElementType());
1471+
initType = cir::PointerType::get(arrayTy.getElementType());
14751472
} else {
14761473
assert(mlir::isa<mlir::TypedAttr>(init) && "This should have a type");
14771474
auto typedInitAttr = mlir::cast<mlir::TypedAttr>(init);
@@ -2393,7 +2390,7 @@ void CIRGenModule::ReplaceUsesOfNonProtoTypeWithRealFunction(
23932390
} else if (auto getGlobalOp = dyn_cast<cir::GetGlobalOp>(use.getUser())) {
23942391
// Replace type
23952392
getGlobalOp.getAddr().setType(
2396-
cir::PointerType::get(&getMLIRContext(), newFn.getFunctionType()));
2393+
cir::PointerType::get(newFn.getFunctionType()));
23972394
} else {
23982395
llvm_unreachable("NIY");
23992396
}

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2210,8 +2210,7 @@ LogicalResult cir::VTableAddrPointOp::verify() {
22102210
auto intTy = cir::IntType::get(getContext(), 32, /*isSigned=*/false);
22112211
auto fnTy = cir::FuncType::get({}, intTy);
22122212

2213-
auto resTy = cir::PointerType::get(getContext(),
2214-
cir::PointerType::get(getContext(), fnTy));
2213+
auto resTy = cir::PointerType::get(cir::PointerType::get(fnTy));
22152214

22162215
if (resultType != resTy)
22172216
return emitOpError("result type must be '")
@@ -2258,8 +2257,8 @@ LogicalResult cir::VTTAddrPointOp::verify() {
22582257
auto resultType = getAddr().getType();
22592258

22602259
auto resTy = cir::PointerType::get(
2261-
getContext(),
2262-
cir::PointerType::get(getContext(), cir::VoidType::get(getContext())));
2260+
2261+
cir::PointerType::get(cir::VoidType::get(getContext())));
22632262

22642263
if (resultType != resTy)
22652264
return emitOpError("result type must be '")

clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct CallConvLowering {
6969

7070
mlir::Type convert(mlir::Type t) {
7171
if (auto fTy = getFuncPointerTy(t))
72-
return PointerType::get(rewriter.getContext(), convert(fTy));
72+
return cir::PointerType::get(convert(fTy));
7373
return t;
7474
}
7575

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,10 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
287287

288288
// Create a runtime helper function:
289289
// extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
290-
auto voidPtrTy = cir::PointerType::get(builder.getContext(), voidTy);
290+
auto voidPtrTy = cir::PointerType::get(voidTy);
291291
auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy);
292-
auto voidFnPtrTy = cir::PointerType::get(builder.getContext(), voidFnTy);
293-
auto HandlePtrTy =
294-
cir::PointerType::get(builder.getContext(), Handle.getSymType());
292+
auto voidFnPtrTy = cir::PointerType::get(voidFnTy);
293+
auto HandlePtrTy = cir::PointerType::get(Handle.getSymType());
295294
auto fnAtExitType =
296295
cir::FuncType::get({voidFnPtrTy, voidPtrTy, HandlePtrTy},
297296
cir::VoidType::get(builder.getContext()));
@@ -303,8 +302,7 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
303302
// &__dso_handle)
304303
builder.setInsertionPointAfter(dtorCall);
305304
mlir::Value args[3];
306-
auto dtorPtrTy =
307-
cir::PointerType::get(builder.getContext(), dtorFunc.getFunctionType());
305+
auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
308306
// dtorPtrTy
309307
args[0] = builder.create<cir::GetGlobalOp>(dtorCall.getLoc(), dtorPtrTy,
310308
dtorFunc.getSymName());

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,9 @@ void ItaniumCXXABI::lowerGetMethod(
375375
// Load vtable pointer.
376376
// Note that vtable pointer always point to the global address space.
377377
auto vtablePtrTy = cir::PointerType::get(
378-
rewriter.getContext(),
379378
cir::IntType::get(rewriter.getContext(), 8, true));
380379
auto vtablePtrPtrTy = cir::PointerType::get(
381-
rewriter.getContext(), vtablePtrTy,
380+
vtablePtrTy,
382381
mlir::cast<cir::PointerType>(op.getObject().getType()).getAddrSpace());
383382
auto vtablePtrPtr = rewriter.create<cir::CastOp>(
384383
op.getLoc(), vtablePtrPtrTy, cir::CastKind::bitcast, loweredObjectPtr);
@@ -413,8 +412,7 @@ void ItaniumCXXABI::lowerGetMethod(
413412
else {
414413
mlir::Value vfpAddr = rewriter.create<cir::PtrStrideOp>(
415414
op.getLoc(), vtablePtrTy, vtablePtr, vtableOffset);
416-
auto vfpPtrTy =
417-
cir::PointerType::get(rewriter.getContext(), calleePtrTy);
415+
auto vfpPtrTy = cir::PointerType::get(calleePtrTy);
418416
mlir::Value vfpPtr = rewriter.create<cir::CastOp>(
419417
op.getLoc(), vfpPtrTy, cir::CastKind::bitcast, vfpAddr);
420418
funcPtr = rewriter.create<cir::LoadOp>(

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ mlir::Value buildAddressAtOffset(LowerFunction &LF, mlir::Value addr,
4343

4444
mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy,
4545
LowerFunction &CGF) {
46-
auto destPtrTy = PointerType::get(CGF.getRewriter().getContext(), DestTy);
46+
auto destPtrTy = cir::PointerType::get(DestTy);
4747

4848
if (auto load = mlir::dyn_cast<LoadOp>(Src.getDefiningOp()))
4949
return CGF.getRewriter().create<CastOp>(Src.getLoc(), destPtrTy,
@@ -86,8 +86,7 @@ mlir::Value enterRecordPointerForCoercedAccess(mlir::Value SrcPtr,
8686
return SrcPtr;
8787

8888
auto &rw = CGF.getRewriter();
89-
auto *ctxt = rw.getContext();
90-
auto ptrTy = PointerType::get(ctxt, FirstElt);
89+
auto ptrTy = PointerType::get(FirstElt);
9190
if (mlir::isa<RecordType>(SrcPtr.getType())) {
9291
auto addr = SrcPtr;
9392
if (auto load = mlir::dyn_cast<LoadOp>(SrcPtr.getDefiningOp()))
@@ -168,7 +167,6 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
168167

169168
AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
170169
auto &rw = LF.getRewriter();
171-
auto *ctxt = rw.getContext();
172170
mlir::PatternRewriter::InsertionGuard guard(rw);
173171

174172
// find function's entry block and use it to find a best place for alloca
@@ -184,7 +182,7 @@ AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
184182

185183
auto align = LF.LM.getDataLayout().getABITypeAlign(ty);
186184
auto alignAttr = rw.getI64IntegerAttr(align.value());
187-
auto ptrTy = PointerType::get(ctxt, ty);
185+
auto ptrTy = PointerType::get(ty);
188186
return rw.create<AllocaOp>(loc, ptrTy, ty, "tmp", alignAttr);
189187
}
190188

@@ -201,7 +199,7 @@ MemCpyOp createMemCpy(LowerFunction &LF, mlir::Value dst, mlir::Value src,
201199

202200
auto *ctxt = LF.getRewriter().getContext();
203201
auto &rw = LF.getRewriter();
204-
auto voidPtr = PointerType::get(ctxt, cir::VoidType::get(ctxt));
202+
auto voidPtr = PointerType::get(cir::VoidType::get(ctxt));
205203

206204
if (!isVoidPtr(src))
207205
src = createBitcast(src, voidPtr, LF);
@@ -286,7 +284,7 @@ void createCoercedStore(mlir::Value Src, mlir::Value Dst, bool DstIsVolatile,
286284
auto *ctxt = CGF.LM.getMLIRContext();
287285
auto dstIntTy = IntType::get(ctxt, DstSize.getFixedValue() * 8, false);
288286
Src = coerceIntOrPtrToIntOrPtr(Src, dstIntTy, CGF);
289-
auto ptrTy = PointerType::get(ctxt, dstIntTy);
287+
auto ptrTy = PointerType::get(dstIntTy);
290288
auto addr = bld.create<CastOp>(Dst.getLoc(), ptrTy, CastKind::bitcast, Dst);
291289
bld.create<StoreOp>(Dst.getLoc(), Src, addr);
292290
} else {
@@ -1257,7 +1255,7 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
12571255
if (Caller.isIndirect()) {
12581256
rewriter.setInsertionPoint(Caller);
12591257
auto val = Caller.getIndirectCall();
1260-
auto ptrTy = PointerType::get(val.getContext(), IRFuncTy);
1258+
auto ptrTy = PointerType::get(IRFuncTy);
12611259
auto callee =
12621260
rewriter.create<CastOp>(val.getLoc(), ptrTy, CastKind::bitcast, val);
12631261
newCallOp = rewriter.create<CallOp>(loc, callee, IRFuncTy, IRCallArgs);

0 commit comments

Comments
 (0)