Skip to content

Commit 9216e17

Browse files
authored
[CIR] Upstream basic support for ExtVector element expr (llvm#167570)
Upstream the basic support for the ExtVectorType element expr
1 parent 98f9b54 commit 9216e17

File tree

6 files changed

+167
-1
lines changed

6 files changed

+167
-1
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,10 +631,49 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
631631
lv.getVectorIdx()));
632632
}
633633

634+
if (lv.isExtVectorElt())
635+
return emitLoadOfExtVectorElementLValue(lv);
636+
634637
cgm.errorNYI(loc, "emitLoadOfLValue");
635638
return RValue::get(nullptr);
636639
}
637640

641+
int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
642+
const mlir::ArrayAttr elts) {
643+
auto elt = mlir::cast<mlir::IntegerAttr>(elts[idx]);
644+
return elt.getInt();
645+
}
646+
647+
// If this is a reference to a subset of the elements of a vector, create an
648+
// appropriate shufflevector.
649+
RValue CIRGenFunction::emitLoadOfExtVectorElementLValue(LValue lv) {
650+
mlir::Location loc = lv.getExtVectorPointer().getLoc();
651+
mlir::Value vec = builder.createLoad(loc, lv.getExtVectorAddress());
652+
653+
// HLSL allows treating scalars as one-element vectors. Converting the scalar
654+
// IR value to a vector here allows the rest of codegen to behave as normal.
655+
if (getLangOpts().HLSL && !mlir::isa<cir::VectorType>(vec.getType())) {
656+
cgm.errorNYI(loc, "emitLoadOfExtVectorElementLValue: HLSL");
657+
return {};
658+
}
659+
660+
const mlir::ArrayAttr elts = lv.getExtVectorElts();
661+
662+
// If the result of the expression is a non-vector type, we must be extracting
663+
// a single element. Just codegen as an extractelement.
664+
const auto *exprVecTy = lv.getType()->getAs<clang::VectorType>();
665+
if (!exprVecTy) {
666+
int64_t indexValue = getAccessedFieldNo(0, elts);
667+
cir::ConstantOp index =
668+
builder.getConstInt(loc, builder.getSInt64Ty(), indexValue);
669+
return RValue::get(cir::VecExtractOp::create(builder, loc, vec, index));
670+
}
671+
672+
cgm.errorNYI(
673+
loc, "emitLoadOfExtVectorElementLValue: Result of expr is vector type");
674+
return {};
675+
}
676+
638677
static cir::FuncOp emitFunctionDeclPointer(CIRGenModule &cgm, GlobalDecl gd) {
639678
assert(!cir::MissingFeatures::weakRefReference());
640679
return cgm.getAddrOfFunction(gd);
@@ -1120,6 +1159,46 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
11201159
return lv;
11211160
}
11221161

1162+
LValue CIRGenFunction::emitExtVectorElementExpr(const ExtVectorElementExpr *e) {
1163+
// Emit the base vector as an l-value.
1164+
LValue base;
1165+
1166+
// ExtVectorElementExpr's base can either be a vector or pointer to vector.
1167+
if (e->isArrow()) {
1168+
cgm.errorNYI(e->getSourceRange(),
1169+
"emitExtVectorElementExpr: pointer to vector");
1170+
return {};
1171+
} else if (e->getBase()->isGLValue()) {
1172+
// Otherwise, if the base is an lvalue ( as in the case of foo.x.x),
1173+
// emit the base as an lvalue.
1174+
assert(e->getBase()->getType()->isVectorType());
1175+
base = emitLValue(e->getBase());
1176+
} else {
1177+
// Otherwise, the base is a normal rvalue (as in (V+V).x), emit it as such.
1178+
cgm.errorNYI(e->getSourceRange(),
1179+
"emitExtVectorElementExpr: base is a normal rvalue");
1180+
return {};
1181+
}
1182+
1183+
QualType type =
1184+
e->getType().withCVRQualifiers(base.getQuals().getCVRQualifiers());
1185+
1186+
// Encode the element access list into a vector of unsigned indices.
1187+
SmallVector<uint32_t, 4> indices;
1188+
e->getEncodedElementAccess(indices);
1189+
1190+
if (base.isSimple()) {
1191+
SmallVector<int64_t> attrElts(indices.begin(), indices.end());
1192+
mlir::ArrayAttr elts = builder.getI64ArrayAttr(attrElts);
1193+
return LValue::makeExtVectorElt(base.getAddress(), elts, type,
1194+
base.getBaseInfo());
1195+
}
1196+
1197+
cgm.errorNYI(e->getSourceRange(),
1198+
"emitExtVectorElementExpr: isSimple is false");
1199+
return {};
1200+
}
1201+
11231202
LValue CIRGenFunction::emitStringLiteralLValue(const StringLiteral *e,
11241203
llvm::StringRef name) {
11251204
cir::GlobalOp globalOp = cgm.getGlobalForStringLiteral(e, name);

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
283283
e->getSourceRange().getBegin());
284284
}
285285

286+
mlir::Value VisitExtVectorElementExpr(Expr *e) { return emitLoadOfLValue(e); }
287+
286288
mlir::Value VisitMemberExpr(MemberExpr *e);
287289

288290
mlir::Value VisitCompoundLiteralExpr(CompoundLiteralExpr *e) {

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
887887
return emitConditionalOperatorLValue(cast<BinaryConditionalOperator>(e));
888888
case Expr::ArraySubscriptExprClass:
889889
return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
890+
case Expr::ExtVectorElementExprClass:
891+
return emitExtVectorElementExpr(cast<ExtVectorElementExpr>(e));
890892
case Expr::UnaryOperatorClass:
891893
return emitUnaryOpLValue(cast<UnaryOperator>(e));
892894
case Expr::StringLiteralClass:

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,8 @@ class CIRGenFunction : public CIRGenTypeCache {
12771277
QualType &baseType, Address &addr);
12781278
LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
12791279

1280+
LValue emitExtVectorElementExpr(const ExtVectorElementExpr *e);
1281+
12801282
Address emitArrayToPointerDecay(const Expr *e,
12811283
LValueBaseInfo *baseInfo = nullptr);
12821284

@@ -1342,6 +1344,8 @@ class CIRGenFunction : public CIRGenTypeCache {
13421344
mlir::Value emittedE,
13431345
bool isDynamic);
13441346

1347+
int64_t getAccessedFieldNo(unsigned idx, mlir::ArrayAttr elts);
1348+
13451349
RValue emitCall(const CIRGenFunctionInfo &funcInfo,
13461350
const CIRGenCallee &callee, ReturnValueSlot returnValue,
13471351
const CallArgList &args, cir::CIRCallOpInterface *callOp,
@@ -1637,6 +1641,8 @@ class CIRGenFunction : public CIRGenTypeCache {
16371641
/// Load a complex number from the specified l-value.
16381642
mlir::Value emitLoadOfComplex(LValue src, SourceLocation loc);
16391643

1644+
RValue emitLoadOfExtVectorElementLValue(LValue lv);
1645+
16401646
/// Given an expression that represents a value lvalue, this method emits
16411647
/// the address of the lvalue, then loads the result as an rvalue,
16421648
/// returning the rvalue.

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ class LValue {
166166
// this is the alignment of the whole vector)
167167
unsigned alignment;
168168
mlir::Value v;
169-
mlir::Value vectorIdx; // Index for vector subscript
169+
mlir::Value vectorIdx; // Index for vector subscript
170+
mlir::Attribute vectorElts; // ExtVector element subset: V.xyx
170171
mlir::Type elementType;
171172
LValueBaseInfo baseInfo;
172173
const CIRGenBitFieldInfo *bitFieldInfo{nullptr};
@@ -190,6 +191,7 @@ class LValue {
190191
bool isSimple() const { return lvType == Simple; }
191192
bool isVectorElt() const { return lvType == VectorElt; }
192193
bool isBitField() const { return lvType == BitField; }
194+
bool isExtVectorElt() const { return lvType == ExtVectorElt; }
193195
bool isGlobalReg() const { return lvType == GlobalReg; }
194196
bool isVolatile() const { return quals.hasVolatile(); }
195197

@@ -254,6 +256,22 @@ class LValue {
254256
return vectorIdx;
255257
}
256258

259+
// extended vector elements.
260+
Address getExtVectorAddress() const {
261+
assert(isExtVectorElt());
262+
return Address(getExtVectorPointer(), elementType, getAlignment());
263+
}
264+
265+
mlir::Value getExtVectorPointer() const {
266+
assert(isExtVectorElt());
267+
return v;
268+
}
269+
270+
mlir::ArrayAttr getExtVectorElts() const {
271+
assert(isExtVectorElt());
272+
return mlir::cast<mlir::ArrayAttr>(vectorElts);
273+
}
274+
257275
static LValue makeVectorElt(Address vecAddress, mlir::Value index,
258276
clang::QualType t, LValueBaseInfo baseInfo) {
259277
LValue r;
@@ -265,6 +283,19 @@ class LValue {
265283
return r;
266284
}
267285

286+
static LValue makeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
287+
clang::QualType type,
288+
LValueBaseInfo baseInfo) {
289+
LValue r;
290+
r.lvType = ExtVectorElt;
291+
r.v = vecAddress.getPointer();
292+
r.elementType = vecAddress.getElementType();
293+
r.vectorElts = elts;
294+
r.initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
295+
baseInfo);
296+
return r;
297+
}
298+
268299
// bitfield lvalue
269300
Address getBitFieldAddress() const {
270301
return Address(getBitFieldPointer(), elementType, getAlignment());
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
typedef int vi4 __attribute__((ext_vector_type(4)));
9+
10+
void element_expr_from_gl() {
11+
vi4 a;
12+
int x = a.x;
13+
int y = a.y;
14+
}
15+
16+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
17+
// CIR: %[[X_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init]
18+
// CIR: %[[Y_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init]
19+
// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
20+
// CIR: %[[CONST_0:.*]] = cir.const #cir.int<0> : !s64i
21+
// CIR: %[[ELEM_0:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_0]] : !s64i] : !cir.vector<4 x !s32i>
22+
// CIR: cir.store {{.*}} %[[ELEM_0]], %[[X_ADDR]] : !s32i, !cir.ptr<!s32i>
23+
// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
24+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s64i
25+
// CIR: %[[ELEM_1:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_1]] : !s64i] : !cir.vector<4 x !s32i>
26+
// CIR: cir.store {{.*}} %[[ELEM_1]], %[[Y_ADDR]] : !s32i, !cir.ptr<!s32i>
27+
28+
// LLVM: %[[A_ADDR:.*]] = alloca <4 x i32>, i64 1, align 16
29+
// LLVM: %[[X_ADDR:.*]] = alloca i32, i64 1, align 4
30+
// LLVM: %[[Y_ADDR:.*]] = alloca i32, i64 1, align 4
31+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
32+
// LLVM: %[[ELEM_0:.*]] = extractelement <4 x i32> %4, i64 0
33+
// LLVM: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
34+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
35+
// LLVM: %[[ELEM_1:.*]] = extractelement <4 x i32> %6, i64 1
36+
// LLVM: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4
37+
38+
// OGCG: %[[A_ADDR:.*]] = alloca <4 x i32>, align 16
39+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
40+
// OGCG: %[[Y_ADDR:.*]] = alloca i32, align 4
41+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
42+
// OGCG: %[[ELEM_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 0
43+
// OGCG: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
44+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
45+
// OGCG: %[[ELEM_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 1
46+
// OGCG: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4

0 commit comments

Comments
 (0)