Skip to content

Commit 11de7d3

Browse files
committed
[CIR] Upstream insert op for VectorType
1 parent 7feba5f commit 11de7d3

File tree

8 files changed

+397
-10
lines changed

8 files changed

+397
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,6 +1969,43 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
19691969
let hasVerifier = 1;
19701970
}
19711971

1972+
//===----------------------------------------------------------------------===//
1973+
// VecInsertOp
1974+
//===----------------------------------------------------------------------===//
1975+
1976+
def VecInsertOp : CIR_Op<"vec.insert", [Pure,
1977+
TypesMatchWith<"argument type matches vector element type", "vec", "value",
1978+
"cast<VectorType>($_self).getElementType()">,
1979+
AllTypesMatch<["result", "vec"]>]> {
1980+
1981+
let summary = "Insert one element into a vector object";
1982+
let description = [{
1983+
The `cir.vec.insert` operation replaces the element of the given vector at
1984+
the given index with the given value. The new vector with the inserted
1985+
element is returned.
1986+
1987+
```mlir
1988+
%value = cir.const #cir.int<5> : !s32i
1989+
%index = cir.const #cir.int<2> : !s32i
1990+
%vec_tmp = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1991+
%new_vec = cir.vec.insert %index, %vec_tmp[%value : !s32i] : !cir.vector<4 x !s32i>
1992+
```
1993+
}];
1994+
1995+
let arguments = (ins
1996+
CIR_VectorType:$vec,
1997+
AnyType:$value,
1998+
CIR_AnyFundamentalIntType:$index
1999+
);
2000+
2001+
let results = (outs CIR_VectorType:$result);
2002+
2003+
let assemblyFormat = [{
2004+
$value `,` $vec `[` $index `:` type($index) `]` attr-dict `:`
2005+
qualified(type($vec))
2006+
}];
2007+
}
2008+
19722009
//===----------------------------------------------------------------------===//
19732010
// VecExtractOp
19742011
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,17 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
205205
void CIRGenFunction::emitStoreThroughLValue(RValue src, LValue dst,
206206
bool isInit) {
207207
if (!dst.isSimple()) {
208+
if (dst.isVectorElt()) {
209+
// Read/modify/write the vector, inserting the new element
210+
const mlir::Location loc = dst.getVectorPointer().getLoc();
211+
const mlir::Value vector =
212+
builder.createLoad(loc, dst.getVectorAddress().getPointer());
213+
const mlir::Value newVector = builder.create<cir::VecInsertOp>(
214+
loc, vector, src.getScalarVal(), dst.getVectorIdx());
215+
builder.createStore(loc, newVector, dst.getVectorAddress().getPointer());
216+
return;
217+
}
218+
208219
cgm.errorNYI(dst.getPointer().getLoc(),
209220
"emitStoreThroughLValue: non-simple lvalue");
210221
return;
@@ -418,6 +429,13 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
418429
if (lv.isSimple())
419430
return RValue::get(emitLoadOfScalar(lv, loc));
420431

432+
if (lv.isVectorElt()) {
433+
auto load =
434+
builder.createLoad(getLoc(loc), lv.getVectorAddress().getPointer());
435+
return RValue::get(builder.create<cir::VecExtractOp>(getLoc(loc), load,
436+
lv.getVectorIdx()));
437+
}
438+
421439
cgm.errorNYI(loc, "emitLoadOfLValue");
422440
return RValue::get(nullptr);
423441
}
@@ -638,12 +656,6 @@ static Address emitArraySubscriptPtr(CIRGenFunction &cgf,
638656

639657
LValue
640658
CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
641-
if (e->getBase()->getType()->isVectorType() &&
642-
!isa<ExtVectorElementExpr>(e->getBase())) {
643-
cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: VectorType");
644-
return LValue::makeAddr(Address::invalid(), e->getType(), LValueBaseInfo());
645-
}
646-
647659
if (isa<ExtVectorElementExpr>(e->getBase())) {
648660
cgm.errorNYI(e->getSourceRange(),
649661
"emitArraySubscriptExpr: ExtVectorElementExpr");
@@ -666,18 +678,26 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
666678
assert((e->getIdx() == e->getLHS() || e->getIdx() == e->getRHS()) &&
667679
"index was neither LHS nor RHS");
668680

669-
auto emitIdxAfterBase = [&]() -> mlir::Value {
681+
auto emitIdxAfterBase = [&](bool promote) -> mlir::Value {
670682
const mlir::Value idx = emitScalarExpr(e->getIdx());
671683

672684
// Extend or truncate the index type to 32 or 64-bits.
673685
auto ptrTy = mlir::dyn_cast<cir::PointerType>(idx.getType());
674-
if (ptrTy && mlir::isa<cir::IntType>(ptrTy.getPointee()))
686+
if (promote && ptrTy && mlir::isa<cir::IntType>(ptrTy.getPointee()))
675687
cgm.errorNYI(e->getSourceRange(),
676688
"emitArraySubscriptExpr: index type cast");
677689
return idx;
678690
};
679691

680-
const mlir::Value idx = emitIdxAfterBase();
692+
if (e->getBase()->getType()->isVectorType() &&
693+
!isa<ExtVectorElementExpr>(e->getBase())) {
694+
const mlir::Value idx = emitIdxAfterBase(/*promote=*/false);
695+
const LValue lhs = emitLValue(e->getBase());
696+
return LValue::makeVectorElt(lhs.getAddress(), idx, e->getBase()->getType(),
697+
lhs.getBaseInfo());
698+
}
699+
700+
const mlir::Value idx = emitIdxAfterBase(/*promote=*/true);
681701
if (const Expr *array = getSimpleArrayDecayOperand(e->getBase())) {
682702
LValue arrayLV;
683703
if (const auto *ase = dyn_cast<ArraySubscriptExpr>(array))

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class LValue {
115115
// this is the alignment of the whole vector)
116116
unsigned alignment;
117117
mlir::Value v;
118+
mlir::Value vectorIdx; // Index for vector subscript
118119
mlir::Type elementType;
119120
LValueBaseInfo baseInfo;
120121

@@ -135,6 +136,7 @@ class LValue {
135136

136137
public:
137138
bool isSimple() const { return lvType == Simple; }
139+
bool isVectorElt() const { return lvType == VectorElt; }
138140
bool isBitField() const { return lvType == BitField; }
139141

140142
// TODO: Add support for volatile
@@ -175,6 +177,31 @@ class LValue {
175177
r.initialize(t, t.getQualifiers(), address.getAlignment(), baseInfo);
176178
return r;
177179
}
180+
181+
Address getVectorAddress() const {
182+
return Address(getVectorPointer(), elementType, getAlignment());
183+
}
184+
185+
mlir::Value getVectorPointer() const {
186+
assert(isVectorElt());
187+
return v;
188+
}
189+
190+
mlir::Value getVectorIdx() const {
191+
assert(isVectorElt());
192+
return vectorIdx;
193+
}
194+
195+
static LValue makeVectorElt(Address vecAddress, mlir::Value index,
196+
clang::QualType t, LValueBaseInfo baseInfo) {
197+
LValue r;
198+
r.lvType = VectorElt;
199+
r.v = vecAddress.getPointer();
200+
r.elementType = vecAddress.getElementType();
201+
r.vectorIdx = index;
202+
r.initialize(t, t.getQualifiers(), vecAddress.getAlignment(), baseInfo);
203+
return r;
204+
}
178205
};
179206

180207
/// An aggregate value slot.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1601,7 +1601,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16011601
CIRToLLVMTrapOpLowering,
16021602
CIRToLLVMUnaryOpLowering,
16031603
CIRToLLVMVecCreateOpLowering,
1604-
CIRToLLVMVecExtractOpLowering
1604+
CIRToLLVMVecExtractOpLowering,
1605+
CIRToLLVMVecInsertOpLowering
16051606
// clang-format on
16061607
>(converter, patterns.getContext());
16071608

@@ -1718,6 +1719,14 @@ mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
17181719
return mlir::success();
17191720
}
17201721

1722+
mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
1723+
cir::VecInsertOp op, OpAdaptor adaptor,
1724+
mlir::ConversionPatternRewriter &rewriter) const {
1725+
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
1726+
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
1727+
return mlir::success();
1728+
}
1729+
17211730
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
17221731
return std::make_unique<ConvertCIRToLLVMPass>();
17231732
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ class CIRToLLVMVecExtractOpLowering
313313
mlir::ConversionPatternRewriter &) const override;
314314
};
315315

316+
class CIRToLLVMVecInsertOpLowering
317+
: public mlir::OpConversionPattern<cir::VecInsertOp> {
318+
public:
319+
using mlir::OpConversionPattern<cir::VecInsertOp>::OpConversionPattern;
320+
321+
mlir::LogicalResult
322+
matchAndRewrite(cir::VecInsertOp op, OpAdaptor,
323+
mlir::ConversionPatternRewriter &) const override;
324+
};
325+
316326
} // namespace direct
317327
} // namespace cir
318328

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,126 @@ void foo4() {
213213
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
214214
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
215215
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
216+
217+
void foo5() {
218+
vi4 a = { 1, 2, 3, 4 };
219+
220+
a[2] = 5;
221+
}
222+
223+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
224+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
225+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
226+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
227+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
228+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
229+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
230+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
231+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
232+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
233+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
234+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[CONST_VAL]], %[[TMP]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
235+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
236+
237+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
238+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
239+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
240+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP]], i32 5, i32 2
241+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
242+
243+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
244+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
245+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
246+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP]], i32 5, i32 2
247+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
248+
249+
void foo6() {
250+
vi4 a = { 1, 2, 3, 4 };
251+
int idx = 2;
252+
int value = 5;
253+
a[idx] = value;
254+
}
255+
256+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
257+
// CIR: %[[IDX:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["idx", init]
258+
// CIR: %[[VAL:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["value", init]
259+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
260+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
261+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
262+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
263+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
264+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
265+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
266+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
267+
// CIR: cir.store %[[CONST_IDX]], %[[IDX]] : !s32i, !cir.ptr<!s32i>
268+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
269+
// CIR: cir.store %[[CONST_VAL]], %[[VAL]] : !s32i, !cir.ptr<!s32i>
270+
// CIR: %[[TMP1:.*]] = cir.load %[[VAL]] : !cir.ptr<!s32i>, !s32i
271+
// CIR: %[[TMP2:.*]] = cir.load %[[IDX]] : !cir.ptr<!s32i>, !s32i
272+
// CIR: %[[TMP3:.*]] = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
273+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[TMP1]], %[[TMP3]][%[[TMP2]] : !s32i] : !cir.vector<4 x !s32i>
274+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
275+
276+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
277+
// LLVM: %[[IDX:.*]] = alloca i32, i64 1, align 4
278+
// LLVM: %[[VAL:.*]] = alloca i32, i64 1, align 4
279+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %1, align 16
280+
// LLVM: store i32 2, ptr %[[IDX]], align 4
281+
// LLVM: store i32 5, ptr %[[VAL]], align 4
282+
// LLVM: %[[TMP1:.*]] = load i32, ptr %[[VAL]], align 4
283+
// LLVM: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
284+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
285+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP3]], i32 %[[TMP1]], i32 %[[TMP2]]
286+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
287+
288+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
289+
// OGCG: %[[IDX:.*]] = alloca i32, align 4
290+
// OGCG: %[[VAL:.*]] = alloca i32, align 4
291+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
292+
// OGCG: store i32 2, ptr %[[IDX]], align 4
293+
// OGCG: store i32 5, ptr %[[VAL]], align 4
294+
// OGCG: %[[TMP1:.*]] = load i32, ptr %[[VAL]], align 4
295+
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
296+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
297+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP3]], i32 %[[TMP1]], i32 %[[TMP2]]
298+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
299+
300+
void foo7() {
301+
vi4 a = {1, 2, 3, 4};
302+
a[2] += 5;
303+
}
304+
305+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
306+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
307+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
308+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
309+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
310+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
311+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
312+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
313+
// CIR: %[[CONST_VAL:.*]] = cir.const #cir.int<5> : !s32i
314+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
315+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
316+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
317+
// CIR: %[[RES:.*]] = cir.binop(add, %[[ELE]], %[[CONST_VAL]]) nsw : !s32i
318+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
319+
// CIR: %[[NEW_VEC:.*]] = cir.vec.insert %[[RES]], %[[TMP2]][%[[CONST_IDX]] : !s32i] : !cir.vector<4 x !s32i>
320+
// CIR: cir.store %[[NEW_VEC]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
321+
322+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
323+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
324+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
325+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 2
326+
// LLVM: %[[RES:.*]] = add nsw i32 %[[ELE]], 5
327+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
328+
// LLVM: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
329+
// LLVM: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16
330+
331+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
332+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
333+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
334+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 2
335+
// OGCG: %[[RES:.*]] = add nsw i32 %[[ELE]], 5
336+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
337+
// OGCG: %[[NEW_VEC:.*]] = insertelement <4 x i32> %[[TMP2]], i32 %[[RES]], i32 2
338+
// OGCG: store <4 x i32> %[[NEW_VEC]], ptr %[[VEC]], align 16

0 commit comments

Comments
 (0)