Skip to content

Commit 80dd7f0

Browse files
AmrDeveloper authored and lanza committed
[CIR] Backport ComplexType support in CallExpr args (#1954)
Backport ComplexType support in CallExpr args from the upstream
1 parent d147555 commit 80dd7f0

File tree

3 files changed

+72
-8
lines changed

3 files changed

+72
-8
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ void CIRGenModule::constructAttributeList(
274274
// TODO(cir): add alloc size attr.
275275
}
276276

277-
if (TargetDecl->hasAttr<DeviceKernelAttr>() && DeviceKernelAttr::isOpenCLSpelling(TargetDecl->getAttr<DeviceKernelAttr>())) {
277+
if (TargetDecl->hasAttr<DeviceKernelAttr>() &&
278+
DeviceKernelAttr::isOpenCLSpelling(
279+
TargetDecl->getAttr<DeviceKernelAttr>())) {
278280
auto cirKernelAttr = cir::OpenCLKernelAttr::get(&getMLIRContext());
279281
funcAttrs.set(cirKernelAttr.getMnemonic(), cirKernelAttr);
280282

@@ -476,7 +478,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
476478
I != E; ++I, ++type_it, ++ArgNo) {
477479

478480
mlir::Type argType = convertType(*type_it);
479-
if (!mlir::isa<cir::RecordType>(argType)) {
481+
if (!mlir::isa<cir::RecordType, cir::ComplexType>(argType)) {
480482
mlir::Value V;
481483
assert(!I->isAggregate() && "Aggregate NYI");
482484
V = I->getKnownRValue().getScalarVal();
@@ -496,16 +498,16 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
496498
// FIXME: Avoid the conversion through memory if possible.
497499
Address Src = Address::invalid();
498500
if (!I->isAggregate()) {
499-
llvm_unreachable("NYI");
501+
Src = CreateMemTemp(I->Ty, loc, "coerce");
502+
I->copyInto(*this, Src, loc);
500503
} else {
501504
Src = I->hasLValue() ? I->getKnownLValue().getAddress()
502505
: I->getKnownRValue().getAggregateAddress();
503506
}
504507

505508
// Fast-isel and the optimizer generally like scalar values better than
506509
// FCAs, so we flatten them if this is safe to do for this argument.
507-
auto STy = cast<cir::RecordType>(argType);
508-
auto SrcTy = Src.getElementType();
510+
auto srcTy = Src.getElementType();
509511
// FIXME(cir): get proper location for each argument.
510512
auto argLoc = loc;
511513

@@ -519,13 +521,13 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
519521
// uint64_t SrcSize = CGM.getDataLayout().getTypeAllocSize(SrcTy);
520522
// uint64_t DstSize = CGM.getDataLayout().getTypeAllocSize(STy);
521523
// if (SrcSize < DstSize) {
522-
if (SrcTy != STy)
524+
if (srcTy != argType)
523525
llvm_unreachable("NYI");
524526
else {
525527
// FIXME(cir): this currently only runs when the types are different,
526528
// but should be when alloc sizes are different, fix this as soon as
527529
// datalayout gets introduced.
528-
Src = builder.createElementBitCast(argLoc, Src, STy);
530+
Src = builder.createElementBitCast(argLoc, Src, argType);
529531
}
530532

531533
// assert(NumCIRArgs == STy.getMembers().size());
@@ -757,6 +759,18 @@ mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
757759
return call->getResult(0);
758760
}
759761

762+
// Writes this call argument's value into the memory designated by `addr`.
//
// cgf:  function-level code generator used to build the destination lvalue
//       and to emit the store.
// addr: destination address; it is wrapped as an lvalue of this argument's
//       type (`Ty`).
// loc:  MLIR location attached to the emitted store.
//
// Only one case is implemented here: an argument that has no lvalue and whose
// rvalue is complex. Every other case aborts via llvm_unreachable.
void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
763+
mlir::Location loc) const {
764+
LValue dst = cgf.makeAddrLValue(addr, Ty);
765+
// Scalar rvalues are not handled yet.
if (!HasLV && RV.isScalar())
766+
llvm_unreachable("copyInto scalar value");
767+
// Complex rvalue: store it into the destination as an initialization.
else if (!HasLV && RV.isComplex())
768+
cgf.emitStoreOfComplex(loc, RV.getComplexVal(), dst, /*isInit=*/true);
769+
// Everything else reaches this branch, including arguments that carry an
// lvalue (HasLV); not handled yet.
else
770+
llvm_unreachable("copyInto hasLV");
771+
// Written from a const method, so IsUsed is presumably declared mutable in
// CallArg -- NOTE(review): confirm against the struct definition.
IsUsed = true;
772+
}
773+
760774
void CIRGenFunction::emitCallArg(CallArgList &args, const Expr *E,
761775
QualType type) {
762776
// TODO: Add the DisableDebugLocationUpdates helper
@@ -982,7 +996,8 @@ static void appendParameterTypes(
982996
for (unsigned I = 0, E = FPT->getNumParams(); I != E; ++I) {
983997
prefix.push_back(FPT->getParamType(I));
984998
if (ExtInfos[I].hasPassObjectSize())
985-
prefix.push_back(CGT.getContext().getCanonicalType(CGT.getContext().getSizeType()));
999+
prefix.push_back(
1000+
CGT.getContext().getCanonicalType(CGT.getContext().getSizeType()));
9861001
}
9871002

9881003
addExtParameterInfosForCall(paramInfos, FPT.getTypePtr(), PrefixSize,

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ struct CallArg {
242242
}
243243

244244
bool isAggregate() const { return HasLV || RV.isAggregate(); }
245+
246+
void copyInto(CIRGenFunction &cgf, Address addr, mlir::Location loc) const;
245247
};
246248

247249
class CallArgList : public llvm::SmallVector<CallArg, 8> {

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,50 @@ void atomic_complex_type() {
284284
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[B_ADDR]], i32 0, i32 1
285285
// OGCG: store float %[[ATOMIC_TMP_REAL]], ptr %[[B_REAL_PTR]], align 4
286286
// OGCG: store float %[[ATOMIC_TMP_IMAG]], ptr %[[B_IMAG_PTR]], align 4
287+
288+
// Test input: a function taking a `float _Complex` by value, with an
// intentionally empty body. The checks that follow match the alloca and
// store emitted for parameter `a`.
void complex_type_parameter(float _Complex a) {}
289+
290+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["a", init]
291+
// CIR: cir.store %{{.*}}, %[[A_ADDR]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
292+
293+
// TODO(CIR): the difference between the CIR LLVM and OGCG is because of the lack of calling convention lowering,
294+
// Test will be updated when that is implemented
295+
296+
// LLVM: %[[A_ADDR:.*]] = alloca { float, float }, i64 1, align 4
297+
// LLVM: store { float, float } %{{.*}}, ptr %[[A_ADDR]], align 4
298+
299+
// OGCG: %[[A_ADDR:.*]] = alloca { float, float }, align 4
300+
// OGCG: store <2 x float> %a.coerce, ptr %[[A_ADDR]], align 4
301+
302+
// Test input: passes a local `float _Complex` by value to
// complex_type_parameter. The local is deliberately left uninitialized; the
// checks that follow only match the load/store sequence through the "coerce"
// temporary and the call itself.
void complex_type_argument() {
303+
float _Complex a;
304+
complex_type_parameter(a);
305+
}
306+
307+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["a"]
308+
// CIR: %[[ARG_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["coerce"]
309+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
310+
// CIR: cir.store{{.*}} %[[TMP_A]], %[[ARG_ADDR]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
311+
// CIR: %[[TMP_ARG:.*]] = cir.load{{.*}} %[[ARG_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
312+
// CIR: cir.call @_Z22complex_type_parameterCf(%[[TMP_ARG]]) : (!cir.complex<!cir.float>) -> ()
313+
314+
// LLVM: %[[A_ADDR:.*]] = alloca { float, float }, i64 1, align 4
315+
// LLVM: %[[ARG_ADDR:.*]] = alloca { float, float }, i64 1, align 4
316+
// LLVM: %[[TMP_A:.*]] = load { float, float }, ptr %[[A_ADDR]], align 4
317+
// LLVM: store { float, float } %[[TMP_A]], ptr %[[ARG_ADDR]], align 4
318+
// LLVM: %[[TMP_ARG:.*]] = load { float, float }, ptr %[[ARG_ADDR]], align 4
319+
// LLVM: call void @_Z22complex_type_parameterCf({ float, float } %[[TMP_ARG]])
320+
321+
// OGCG: %[[A_ADDR:.*]] = alloca { float, float }, align 4
322+
// OGCG: %[[ARG_ADDR:.*]] = alloca { float, float }, align 4
323+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[A_ADDR]], i32 0, i32 0
324+
// OGCG: %[[A_REAL:.*]] = load float, ptr %[[A_REAL_PTR]], align 4
325+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[A_ADDR]], i32 0, i32 1
326+
// OGCG: %[[A_IMAG:.*]] = load float, ptr %[[A_IMAG_PTR]], align 4
327+
// OGCG: %[[ARG_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[ARG_ADDR]], i32 0, i32 0
328+
// OGCG: %[[ARG_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[ARG_ADDR]], i32 0, i32 1
329+
// OGCG: store float %[[A_REAL]], ptr %[[ARG_REAL_PTR]], align 4
330+
// OGCG: store float %[[A_IMAG]], ptr %[[ARG_IMAG_PTR]], align 4
331+
// OGCG: %[[TMP_ARG:.*]] = load <2 x float>, ptr %[[ARG_ADDR]], align 4
332+
// OGCG: call void @_Z22complex_type_parameterCf(<2 x float> noundef %[[TMP_ARG]])
333+

0 commit comments

Comments
 (0)