Skip to content

Commit 990d969

Browse files
committed
[CIR] Upstream basic support for ExtVector element expr
1 parent 4d27413 commit 990d969

File tree

6 files changed

+173
-1
lines changed

6 files changed

+173
-1
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,10 +627,51 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
627627
lv.getVectorIdx()));
628628
}
629629

630+
if (lv.isExtVectorElt()) {
631+
return emitLoadOfExtVectorElementLValue(lv);
632+
}
633+
630634
cgm.errorNYI(loc, "emitLoadOfLValue");
631635
return RValue::get(nullptr);
632636
}
633637

638+
int64_t CIRGenFunction::getAccessedFieldNo(unsigned int idx,
639+
const mlir::ArrayAttr elts) {
640+
auto elt = mlir::dyn_cast<mlir::IntegerAttr>(elts[idx]);
641+
assert(elt && "The indices should be integer attributes");
642+
return elt.getInt();
643+
}
644+
645+
// If this is a reference to a subset of the elements of a vector, create an
646+
// appropriate shufflevector.
647+
RValue CIRGenFunction::emitLoadOfExtVectorElementLValue(LValue lv) {
648+
mlir::Location loc = lv.getExtVectorPointer().getLoc();
649+
mlir::Value vec = builder.createLoad(loc, lv.getExtVectorAddress());
650+
651+
// HLSL allows treating scalars as one-element vectors. Converting the scalar
652+
// IR value to a vector here allows the rest of codegen to behave as normal.
653+
if (getLangOpts().HLSL && !mlir::isa<cir::VectorType>(vec.getType())) {
654+
cgm.errorNYI(loc, "emitLoadOfExtVectorElementLValue: HLSL");
655+
return {};
656+
}
657+
658+
const mlir::ArrayAttr elts = lv.getExtVectorElts();
659+
660+
// If the result of the expression is a non-vector type, we must be extracting
661+
// a single element. Just codegen as an extractelement.
662+
const auto *exprVecTy = lv.getType()->getAs<clang::VectorType>();
663+
if (!exprVecTy) {
664+
int64_t indexValue = getAccessedFieldNo(0, elts);
665+
cir::ConstantOp index =
666+
builder.getConstInt(loc, builder.getSInt64Ty(), indexValue);
667+
return RValue::get(cir::VecExtractOp::create(builder, loc, vec, index));
668+
}
669+
670+
cgm.errorNYI(
671+
loc, "emitLoadOfExtVectorElementLValue: Result of expr is vector type");
672+
return {};
673+
}
674+
634675
static cir::FuncOp emitFunctionDeclPointer(CIRGenModule &cgm, GlobalDecl gd) {
635676
assert(!cir::MissingFeatures::weakRefReference());
636677
return cgm.getAddrOfFunction(gd);
@@ -1116,6 +1157,50 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
11161157
return lv;
11171158
}
11181159

1160+
LValue CIRGenFunction::emitExtVectorElementExpr(const ExtVectorElementExpr *e) {
1161+
// Emit the base vector as an l-value.
1162+
LValue base;
1163+
1164+
// ExtVectorElementExpr's base can either be a vector or pointer to vector.
1165+
if (e->isArrow()) {
1166+
cgm.errorNYI(e->getSourceRange(),
1167+
"emitExtVectorElementExpr: pointer to vector");
1168+
return {};
1169+
} else if (e->getBase()->isGLValue()) {
1170+
// Otherwise, if the base is an lvalue ( as in the case of foo.x.x),
1171+
// emit the base as an lvalue.
1172+
assert(e->getBase()->getType()->isVectorType());
1173+
base = emitLValue(e->getBase());
1174+
} else {
1175+
// Otherwise, the base is a normal rvalue (as in (V+V).x), emit it as such.
1176+
cgm.errorNYI(e->getSourceRange(),
1177+
"emitExtVectorElementExpr: base is a normal rvalue");
1178+
return {};
1179+
}
1180+
1181+
QualType type =
1182+
e->getType().withCVRQualifiers(base.getQuals().getCVRQualifiers());
1183+
1184+
// Encode the element access list into a vector of unsigned indices.
1185+
SmallVector<uint32_t, 4> indices;
1186+
e->getEncodedElementAccess(indices);
1187+
1188+
if (base.isSimple()) {
1189+
SmallVector<int64_t> attrElts;
1190+
for (uint32_t i : indices) {
1191+
attrElts.push_back(static_cast<int64_t>(i));
1192+
}
1193+
1194+
mlir::ArrayAttr elts = builder.getI64ArrayAttr(attrElts);
1195+
return LValue::makeExtVectorElt(base.getAddress(), elts, type,
1196+
base.getBaseInfo());
1197+
}
1198+
1199+
cgm.errorNYI(e->getSourceRange(),
1200+
"emitExtVectorElementExpr: isSimple is false");
1201+
return {};
1202+
}
1203+
11191204
LValue CIRGenFunction::emitStringLiteralLValue(const StringLiteral *e,
11201205
llvm::StringRef name) {
11211206
cir::GlobalOp globalOp = cgm.getGlobalForStringLiteral(e, name);

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
277277
e->getSourceRange().getBegin());
278278
}
279279

280+
mlir::Value VisitExtVectorElementExpr(Expr *e) { return emitLoadOfLValue(e); }
281+
280282
mlir::Value VisitMemberExpr(MemberExpr *e);
281283

282284
mlir::Value VisitCompoundLiteralExpr(CompoundLiteralExpr *e) {

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
883883
return emitConditionalOperatorLValue(cast<BinaryConditionalOperator>(e));
884884
case Expr::ArraySubscriptExprClass:
885885
return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
886+
case Expr::ExtVectorElementExprClass:
887+
return emitExtVectorElementExpr(cast<ExtVectorElementExpr>(e));
886888
case Expr::UnaryOperatorClass:
887889
return emitUnaryOpLValue(cast<UnaryOperator>(e));
888890
case Expr::StringLiteralClass:

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,6 +1264,8 @@ class CIRGenFunction : public CIRGenTypeCache {
12641264
QualType &baseType, Address &addr);
12651265
LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
12661266

1267+
LValue emitExtVectorElementExpr(const ExtVectorElementExpr *e);
1268+
12671269
Address emitArrayToPointerDecay(const Expr *e,
12681270
LValueBaseInfo *baseInfo = nullptr);
12691271

@@ -1329,6 +1331,8 @@ class CIRGenFunction : public CIRGenTypeCache {
13291331
mlir::Value emittedE,
13301332
bool isDynamic);
13311333

1334+
int64_t getAccessedFieldNo(unsigned idx, mlir::ArrayAttr elts);
1335+
13321336
RValue emitCall(const CIRGenFunctionInfo &funcInfo,
13331337
const CIRGenCallee &callee, ReturnValueSlot returnValue,
13341338
const CallArgList &args, cir::CIRCallOpInterface *callOp,
@@ -1624,6 +1628,8 @@ class CIRGenFunction : public CIRGenTypeCache {
16241628
/// Load a complex number from the specified l-value.
16251629
mlir::Value emitLoadOfComplex(LValue src, SourceLocation loc);
16261630

1631+
RValue emitLoadOfExtVectorElementLValue(LValue lv);
1632+
16271633
/// Given an expression that represents a value lvalue, this method emits
16281634
/// the address of the lvalue, then loads the result as an rvalue,
16291635
/// returning the rvalue.

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ class LValue {
166166
// this is the alignment of the whole vector)
167167
unsigned alignment;
168168
mlir::Value v;
169-
mlir::Value vectorIdx; // Index for vector subscript
169+
mlir::Value vectorIdx; // Index for vector subscript
170+
mlir::Attribute vectorElts; // ExtVector element subset: V.xyx
170171
mlir::Type elementType;
171172
LValueBaseInfo baseInfo;
172173
const CIRGenBitFieldInfo *bitFieldInfo{nullptr};
@@ -190,6 +191,7 @@ class LValue {
190191
bool isSimple() const { return lvType == Simple; }
191192
bool isVectorElt() const { return lvType == VectorElt; }
192193
bool isBitField() const { return lvType == BitField; }
194+
bool isExtVectorElt() const { return lvType == ExtVectorElt; }
193195
bool isGlobalReg() const { return lvType == GlobalReg; }
194196
bool isVolatile() const { return quals.hasVolatile(); }
195197

@@ -254,6 +256,22 @@ class LValue {
254256
return vectorIdx;
255257
}
256258

259+
// extended vector elements.
260+
Address getExtVectorAddress() const {
261+
assert(isExtVectorElt());
262+
return Address(getExtVectorPointer(), elementType, getAlignment());
263+
}
264+
265+
mlir::Value getExtVectorPointer() const {
266+
assert(isExtVectorElt());
267+
return v;
268+
}
269+
270+
mlir::ArrayAttr getExtVectorElts() const {
271+
assert(isExtVectorElt());
272+
return mlir::cast<mlir::ArrayAttr>(vectorElts);
273+
}
274+
257275
static LValue makeVectorElt(Address vecAddress, mlir::Value index,
258276
clang::QualType t, LValueBaseInfo baseInfo) {
259277
LValue r;
@@ -265,6 +283,19 @@ class LValue {
265283
return r;
266284
}
267285

286+
static LValue makeExtVectorElt(Address vecAddress, mlir::ArrayAttr elts,
287+
clang::QualType type,
288+
LValueBaseInfo baseInfo) {
289+
LValue r;
290+
r.lvType = ExtVectorElt;
291+
r.v = vecAddress.getPointer();
292+
r.elementType = vecAddress.getElementType();
293+
r.vectorElts = elts;
294+
r.initialize(type, type.getQualifiers(), vecAddress.getAlignment(),
295+
baseInfo);
296+
return r;
297+
}
298+
268299
// bitfield lvalue
269300
Address getBitFieldAddress() const {
270301
return Address(getBitFieldPointer(), elementType, getAlignment());
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
typedef int vi4 __attribute__((ext_vector_type(4)));
9+
10+
void element_expr_from_gl() {
11+
vi4 a;
12+
int x = a.x;
13+
int y = a.y;
14+
}
15+
16+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
17+
// CIR: %[[X_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init]
18+
// CIR: %[[Y_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init]
19+
// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
20+
// CIR: %[[CONST_0:.*]] = cir.const #cir.int<0> : !s64i
21+
// CIR: %[[ELEM_0:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_0]] : !s64i] : !cir.vector<4 x !s32i>
22+
// CIR: cir.store {{.*}} %[[ELEM_0]], %[[X_ADDR]] : !s32i, !cir.ptr<!s32i>
23+
// CIR: %[[TMP_A:.*]] = cir.load {{.*}} %[[A_ADDR]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
24+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s64i
25+
// CIR: %[[ELEM_1:.*]] = cir.vec.extract %[[TMP_A]][%[[CONST_1]] : !s64i] : !cir.vector<4 x !s32i>
26+
// CIR: cir.store {{.*}} %[[ELEM_1]], %[[Y_ADDR]] : !s32i, !cir.ptr<!s32i>
27+
28+
// LLVM: %[[A_ADDR:.*]] = alloca <4 x i32>, i64 1, align 16
29+
// LLVM: %[[X_ADDR:.*]] = alloca i32, i64 1, align 4
30+
// LLVM: %[[Y_ADDR:.*]] = alloca i32, i64 1, align 4
31+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
32+
// LLVM: %[[ELEM_0:.*]] = extractelement <4 x i32> %4, i64 0
33+
// LLVM: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
34+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
35+
// LLVM: %[[ELEM_1:.*]] = extractelement <4 x i32> %6, i64 1
36+
// LLVM: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4
37+
38+
// OGCG: %[[A_ADDR:.*]] = alloca <4 x i32>, align 16
39+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
40+
// OGCG: %[[Y_ADDR:.*]] = alloca i32, align 4
41+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
42+
// OGCG: %[[ELEM_0:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 0
43+
// OGCG: store i32 %[[ELEM_0]], ptr %[[X_ADDR]], align 4
44+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[A_ADDR]], align 16
45+
// OGCG: %[[ELEM_1:.*]] = extractelement <4 x i32> %[[TMP_A]], i64 1
46+
// OGCG: store i32 %[[ELEM_1]], ptr %[[Y_ADDR]], align 4

0 commit comments

Comments
 (0)