Skip to content

Commit 0a1b06c

Browse files
authored
[CIR][CIRGen][Builtin][Neon] Lower neon_vqmovns_s32 and add CIR PoisonAttr (#1199)
CIR PoisonOp is needed in this context as alternative would be to use VecCreateOp to prepare an arg for VecInsertElement, but VecCreate is for different purpose and [it would insert all elements](https://github.com/llvm/clangir/blob/eacaabba76ebdbf87217fefaa77f92c45cf4509c/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp#L1679) which is not totally unnecessary in this context. Here is the [intrinsic def ](https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=vqmovns_)
1 parent 7eb09a7 commit 0a1b06c

File tree

7 files changed

+115
-13
lines changed

7 files changed

+115
-13
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,21 @@ def UndefAttr : CIR_Attr<"Undef", "undef", [TypedAttrInterface]> {
169169
let assemblyFormat = [{}];
170170
}
171171

172+
//===----------------------------------------------------------------------===//
173+
// PoisonAttr
174+
//===----------------------------------------------------------------------===//
175+
176+
def PoisonAttr : CIR_Attr<"Poison", "poison", [TypedAttrInterface]> {
177+
let summary = "Represent an poison constant";
178+
let description = [{
179+
The PoisonAttr represents an poison constant, corresponding to LLVM's notion
180+
of poison.
181+
}];
182+
183+
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
184+
let assemblyFormat = [{}];
185+
}
186+
172187
//===----------------------------------------------------------------------===//
173188
// ConstArrayAttr
174189
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2265,6 +2265,32 @@ static mlir::Value emitNeonRShiftImm(CIRGenFunction &cgf, mlir::Value shiftVec,
22652265
false /* right shift */);
22662266
}
22672267

2268+
/// Vectorize value, usually for argument of a neon SISD intrinsic call.
2269+
static void vecExtendIntValue(CIRGenFunction &cgf, cir::VectorType argVTy,
2270+
mlir::Value &arg, mlir::Location loc) {
2271+
CIRGenBuilderTy &builder = cgf.getBuilder();
2272+
cir::IntType eltTy = mlir::dyn_cast<cir::IntType>(argVTy.getEltType());
2273+
assert(mlir::isa<cir::IntType>(arg.getType()) && eltTy);
2274+
// The constant argument to an _n_ intrinsic always has Int32Ty, so truncate
2275+
// it before inserting.
2276+
arg = builder.createIntCast(arg, eltTy);
2277+
mlir::Value zero = builder.getConstInt(loc, cgf.SizeTy, 0);
2278+
mlir::Value poison = builder.create<cir::ConstantOp>(
2279+
loc, eltTy, builder.getAttr<cir::PoisonAttr>(eltTy));
2280+
arg = builder.create<cir::VecInsertOp>(
2281+
loc, builder.create<cir::VecSplatOp>(loc, argVTy, poison), arg, zero);
2282+
}
2283+
2284+
/// Reduce vector type value to scalar, usually for result of a
2285+
/// neon SISD intrinsic call
2286+
static mlir::Value vecReduceIntValue(CIRGenFunction &cgf, mlir::Value val,
2287+
mlir::Location loc) {
2288+
CIRGenBuilderTy &builder = cgf.getBuilder();
2289+
assert(mlir::isa<cir::VectorType>(val.getType()));
2290+
return builder.create<cir::VecExtractOp>(
2291+
loc, val, builder.getConstInt(loc, cgf.SizeTy, 0));
2292+
}
2293+
22682294
mlir::Value emitNeonCall(CIRGenBuilderTy &builder,
22692295
llvm::SmallVector<mlir::Type> argTypes,
22702296
llvm::SmallVectorImpl<mlir::Value> &args,
@@ -2853,8 +2879,17 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
28532879
llvm_unreachable(" neon_vqmovnh_s16 NYI ");
28542880
case NEON::BI__builtin_neon_vqmovnh_u16:
28552881
llvm_unreachable(" neon_vqmovnh_u16 NYI ");
2856-
case NEON::BI__builtin_neon_vqmovns_s32:
2857-
llvm_unreachable(" neon_vqmovns_s32 NYI ");
2882+
case NEON::BI__builtin_neon_vqmovns_s32: {
2883+
mlir::Location loc = cgf.getLoc(expr->getExprLoc());
2884+
cir::VectorType argVecTy =
2885+
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt32Ty, 4);
2886+
cir::VectorType resVecTy =
2887+
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt16Ty, 4);
2888+
vecExtendIntValue(cgf, argVecTy, ops[0], loc);
2889+
mlir::Value result = emitNeonCall(builder, {argVecTy}, ops,
2890+
"aarch64.neon.sqxtn", resVecTy, loc);
2891+
return vecReduceIntValue(cgf, result, loc);
2892+
}
28582893
case NEON::BI__builtin_neon_vqmovns_u32:
28592894
llvm_unreachable(" neon_vqmovns_u32 NYI ");
28602895
case NEON::BI__builtin_neon_vqmovund_s64:

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,12 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
389389
return op->emitOpError("undef expects non-void type");
390390
}
391391

392+
if (isa<cir::PoisonAttr>(attrType)) {
393+
if (!::mlir::isa<cir::VoidType>(opType))
394+
return success();
395+
return op->emitOpError("poison expects non-void type");
396+
}
397+
392398
if (mlir::isa<cir::BoolAttr>(attrType)) {
393399
if (!mlir::isa<cir::BoolType>(opType))
394400
return op->emitOpError("result type (")

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
425425
loc, converter->convertType(undefAttr.getType()));
426426
}
427427

428+
/// PoisonAttr visitor.
429+
static mlir::Value
430+
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
431+
mlir::ConversionPatternRewriter &rewriter,
432+
const mlir::TypeConverter *converter) {
433+
auto loc = parentOp->getLoc();
434+
return rewriter.create<mlir::LLVM::PoisonOp>(
435+
loc, converter->convertType(poisonAttr.getType()));
436+
}
437+
428438
/// ConstStruct visitor.
429439
static mlir::Value
430440
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
@@ -644,6 +654,8 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
644654
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
645655
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
646656
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
657+
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
658+
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
647659
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
648660
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
649661
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
@@ -1555,6 +1567,14 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
15551567
mlir::ConversionPatternRewriter &rewriter) const {
15561568
mlir::Attribute attr = op.getValue();
15571569

1570+
// Regardless of the type, we should lower the constant of poison value
1571+
// into PoisonOp.
1572+
if (mlir::isa<cir::PoisonAttr>(attr)) {
1573+
rewriter.replaceOp(
1574+
op, lowerCirAttrAsValue(op, attr, rewriter, getTypeConverter()));
1575+
return mlir::success();
1576+
}
1577+
15581578
if (mlir::isa<mlir::IntegerType>(op.getType())) {
15591579
// Verified cir.const operations cannot actually be of these types, but the
15601580
// lowering pass may generate temporary cir.const operations with these
@@ -1695,6 +1715,7 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
16951715
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
16961716
assert(vecTy.getSize() == op.getElements().size() &&
16971717
"cir.vec.create op count doesn't match vector type elements count");
1718+
16981719
for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
16991720
mlir::Value indexValue =
17001721
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
@@ -1745,15 +1766,21 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
17451766
assert(vecTy && "result type of cir.vec.splat op is not VectorType");
17461767
auto llvmTy = typeConverter->convertType(vecTy);
17471768
auto loc = op.getLoc();
1748-
mlir::Value undef = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1769+
mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
17491770
mlir::Value indexValue =
17501771
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
17511772
mlir::Value elementValue = adaptor.getValue();
1773+
if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
1774+
// If the splat value is poison, then we can just use poison value
1775+
// for the entire vector.
1776+
rewriter.replaceOp(op, poison);
1777+
return mlir::success();
1778+
}
17521779
mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
1753-
loc, undef, elementValue, indexValue);
1780+
loc, poison, elementValue, indexValue);
17541781
SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
17551782
mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
1756-
loc, oneElement, undef, zeroValues);
1783+
loc, oneElement, poison, zeroValues);
17571784
rewriter.replaceOp(op, shuffled);
17581785
return mlir::success();
17591786
}

clang/test/CIR/CodeGen/AArch64/neon.c

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14611,14 +14611,25 @@ void test_vst1q_s64(int64_t *a, int64x2_t b) {
1461114611
// return (int8_t)vqmovnh_s16(a);
1461214612
// }
1461314613

14614-
// NYI-LABEL: @test_vqmovns_s32(
14615-
// NYI: [[TMP0:%.*]] = insertelement <4 x i32> poison, i32 %a, i64 0
14616-
// NYI: [[VQMOVNS_S32_I:%.*]] = call <4 x i16> @llvm.aarch64.neon.sqxtn.v4i16(<4 x i32> [[TMP0]])
14617-
// NYI: [[TMP1:%.*]] = extractelement <4 x i16> [[VQMOVNS_S32_I]], i64 0
14618-
// NYI: ret i16 [[TMP1]]
14619-
// int16_t test_vqmovns_s32(int32_t a) {
14620-
// return (int16_t)vqmovns_s32(a);
14621-
// }
14614+
int16_t test_vqmovns_s32(int32_t a) {
14615+
return (int16_t)vqmovns_s32(a);
14616+
14617+
// CIR-LABEL: vqmovns_s32
14618+
// CIR: [[A:%.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
14619+
// CIR: [[VQMOVNS_S32_ZERO1:%.*]] = cir.const #cir.int<0> : !u64i
14620+
// CIR: [[POISON:%.*]] = cir.const #cir.poison : !s32i
14621+
// CIR: [[POISON_VEC:%.*]] = cir.vec.splat [[POISON]] : !s32i, !cir.vector<!s32i x 4>
14622+
// CIR: [[TMP0:%.*]] = cir.vec.insert [[A]], [[POISON_VEC]][[[VQMOVNS_S32_ZERO1]] : !u64i] : !cir.vector<!s32i x 4>
14623+
// CIR: [[VQMOVNS_S32_I:%.*]] = cir.llvm.intrinsic "aarch64.neon.sqxtn" [[TMP0]] : (!cir.vector<!s32i x 4>) -> !cir.vector<!s16i x 4>
14624+
// CIR: [[VQMOVNS_S32_ZERO2:%.*]] = cir.const #cir.int<0> : !u64i
14625+
// CIR: [[TMP1:%.*]] = cir.vec.extract [[VQMOVNS_S32_I]][[[VQMOVNS_S32_ZERO2]] : !u64i] : !cir.vector<!s16i x 4>
14626+
14627+
// LLVM: {{.*}}@test_vqmovns_s32(i32{{.*}}[[a:%.*]])
14628+
// LLVM: [[TMP0:%.*]] = insertelement <4 x i32> poison, i32 [[a]], i64 0
14629+
// LLVM: [[VQMOVNS_S32_I:%.*]] = call <4 x i16> @llvm.aarch64.neon.sqxtn.v4i16(<4 x i32> [[TMP0]])
14630+
// LLVM: [[TMP1:%.*]] = extractelement <4 x i16> [[VQMOVNS_S32_I]], i64 0
14631+
// LLVM: ret i16 [[TMP1]]
14632+
}
1462214633

1462314634
// NYI-LABEL: @test_vqmovnd_s64(
1462414635
// NYI: [[VQMOVND_S64_I:%.*]] = call i32 @llvm.aarch64.neon.scalar.sqxtn.i32.i64(i64 %a)

clang/test/CIR/IR/invalid.cir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,12 @@ module {
393393

394394
// -----
395395

396+
module {
397+
cir.global external @v = #cir.poison : !cir.void // expected-error {{poison expects non-void type}}
398+
}
399+
400+
// -----
401+
396402
!s32i = !cir.int<s, 32>
397403
cir.func @vec_op_size() {
398404
%0 = cir.const #cir.int<1> : !s32i

clang/test/CIR/Lowering/const.cir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ module {
1717
// CHECK: llvm.mlir.zero : !llvm.array<3 x i32>
1818
%5 = cir.const #cir.undef : !cir.array<!s32i x 3>
1919
// CHECK: llvm.mlir.undef : !llvm.array<3 x i32>
20+
%6 = cir.const #cir.poison : !s32i
21+
// CHECK: llvm.mlir.poison : i32
2022
cir.return
2123
}
2224

0 commit comments

Comments
 (0)