Skip to content

Commit 05b7720

Browse files
Priyanshu3820Copilotandykaylor
authored
[CIR][X86] Implement lowering for sqrt builtins (#169310)
Implements CIR IR generation for X86-specific sqrt builtin functions, addressing issue #167765. ## Test Results Successfully tested the implementation locally. All tests pass: ```bash $ ./bin/llvm-lit -v ../clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c Testing: 1 tests, 1 workers PASS: Clang :: CIR/CodeGen/X86/cir-sqrt-builtins.c (1 of 1) Testing Time: 1.18s Total Discovered Tests: 1 Passed: 1 (100.00%) ``` --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Andy Kaylor <[email protected]>
1 parent 786498b commit 05b7720

File tree

4 files changed

+81
-3
lines changed

4 files changed

+81
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4756,6 +4756,27 @@ class CIR_UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
47564756
let llvmOp = llvmOpName;
47574757
}
47584758

4759+
def CIR_SqrtOp : CIR_UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp"> {
4760+
let summary = "Floating-point square root operation";
4761+
4762+
let description = [{
4763+
Computes the square root of a floating-point value or vector.
4764+
4765+
The input must be either:
4766+
• a floating-point scalar type, or
4767+
• a vector whose element type is floating-point.
4768+
4769+
The result type must match the input type exactly.
4770+
4771+
Examples:
4772+
// scalar
4773+
%r = cir.sqrt %x : !cir.fp64
4774+
4775+
// vector
4776+
%v = cir.sqrt %vec : !cir.vector<!cir.fp32 x 4>
4777+
}];
4778+
}
4779+
47594780
def CIR_ACosOp : CIR_UnaryFPToFPBuiltinOp<"acos", "ACosOp"> {
47604781
let summary = "Computes the arcus cosine of the specified value";
47614782
let description = [{

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,13 +1347,17 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
13471347
case X86::BI__builtin_ia32_sqrtsh_round_mask:
13481348
case X86::BI__builtin_ia32_sqrtsd_round_mask:
13491349
case X86::BI__builtin_ia32_sqrtss_round_mask:
1350-
case X86::BI__builtin_ia32_sqrtph512:
1351-
case X86::BI__builtin_ia32_sqrtps512:
1352-
case X86::BI__builtin_ia32_sqrtpd512:
13531350
cgm.errorNYI(expr->getSourceRange(),
13541351
std::string("unimplemented X86 builtin call: ") +
13551352
getContext().BuiltinInfo.getName(builtinID));
13561353
return {};
1354+
case X86::BI__builtin_ia32_sqrtph512:
1355+
case X86::BI__builtin_ia32_sqrtps512:
1356+
case X86::BI__builtin_ia32_sqrtpd512: {
1357+
mlir::Location loc = getLoc(expr->getExprLoc());
1358+
mlir::Value arg = ops[0];
1359+
return cir::SqrtOp::create(builder, loc, arg.getType(), arg).getResult();
1360+
}
13571361
case X86::BI__builtin_ia32_pmuludq128:
13581362
case X86::BI__builtin_ia32_pmuludq256:
13591363
case X86::BI__builtin_ia32_pmuludq512: {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,14 @@ mlir::LogicalResult CIRToLLVMCopyOpLowering::matchAndRewrite(
186186
return mlir::success();
187187
}
188188

189+
mlir::LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite(
190+
cir::SqrtOp op, OpAdaptor adaptor,
191+
mlir::ConversionPatternRewriter &rewriter) const {
192+
mlir::Type resTy = typeConverter->convertType(op.getType());
193+
rewriter.replaceOpWithNewOp<mlir::LLVM::SqrtOp>(op, resTy, adaptor.getSrc());
194+
return mlir::success();
195+
}
196+
189197
mlir::LogicalResult CIRToLLVMCosOpLowering::matchAndRewrite(
190198
cir::CosOp op, OpAdaptor adaptor,
191199
mlir::ConversionPatternRewriter &rewriter) const {
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Test X86-specific sqrt builtins
2+
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
4+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
6+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
7+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +avx512f -target-feature +avx512fp16 -emit-llvm %s -o %t.ll
8+
// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
9+
10+
typedef float __m512 __attribute__((__vector_size__(64), __aligned__(64)));
11+
typedef double __m512d __attribute__((__vector_size__(64), __aligned__(64)));
12+
typedef _Float16 __m512h __attribute__((__vector_size__(64), __aligned__(64)));
13+
14+
// Test __builtin_ia32_sqrtph512
15+
__m512h test_sqrtph512(__m512h a) {
16+
return __builtin_ia32_sqrtph512(a, 4);
17+
}
18+
// CIR-LABEL: cir.func {{.*}}@test_sqrtph512
19+
// CIR: cir.sqrt {{%.*}} : !cir.vector<32 x !cir.f16>
20+
// LLVM-LABEL: define {{.*}} @test_sqrtph512
21+
// LLVM: call <32 x half> @llvm.sqrt.v32f16
22+
// OGCG-LABEL: define {{.*}} @test_sqrtph512
23+
// OGCG: call <32 x half> @llvm.sqrt.v32f16
24+
25+
// Test __builtin_ia32_sqrtps512
26+
__m512 test_sqrtps512(__m512 a) {
27+
return __builtin_ia32_sqrtps512(a, 4);
28+
}
29+
// CIR-LABEL: cir.func {{.*}}@test_sqrtps512
30+
// CIR: cir.sqrt {{%.*}} : !cir.vector<16 x !cir.float>
31+
// LLVM-LABEL: define {{.*}} @test_sqrtps512
32+
// LLVM: call <16 x float> @llvm.sqrt.v16f32
33+
// OGCG-LABEL: define {{.*}} @test_sqrtps512
34+
// OGCG: call <16 x float> @llvm.sqrt.v16f32
35+
36+
// Test __builtin_ia32_sqrtpd512
37+
__m512d test_sqrtpd512(__m512d a) {
38+
return __builtin_ia32_sqrtpd512(a, 4);
39+
}
40+
// CIR-LABEL: cir.func {{.*}}@test_sqrtpd512
41+
// CIR: cir.sqrt {{%.*}} : !cir.vector<8 x !cir.double>
42+
// LLVM-LABEL: define {{.*}} @test_sqrtpd512
43+
// LLVM: call <8 x double> @llvm.sqrt.v8f64
44+
// OGCG-LABEL: define {{.*}} @test_sqrtpd512
45+
// OGCG: call <8 x double> @llvm.sqrt.v8f64

0 commit comments

Comments
 (0)