
Commit 82efd72

[MLIR] Add sincos op to math dialect (#160772)
Now that `sincos` is a supported intrinsic in the LLVM dialect (#160561), we can add the corresponding operation to the math dialect along with conversion patterns for LLVM and NVVM. Several of our benchmarks use sine and cosine in hot loops, and computing the two together saves work and can benefit performance, so we want a way to represent sincos in the math dialect.
1 parent 98d43ef commit 82efd72
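
For orientation, a minimal sketch of the new op as exercised by the tests in this patch (value names are illustrative): it lowers either to the `llvm.intr.sincos` intrinsic via MathToLLVM or to libdevice `__nv_sincosf`/`__nv_sincos` calls via GPUToNVVM.

```mlir
// One operand, two results; scalar, vector, and tensor operands of floating point type are supported.
%sin, %cos = math.sincos %x : f32
```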

File tree: 7 files changed (+239, -1 lines)


mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 37 additions & 0 deletions
@@ -510,6 +510,43 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SinCosOp
+//===----------------------------------------------------------------------===//
+
+def Math_SincosOp : Math_Op<"sincos",
+    [SameOperandsAndResultShape,
+     DeclareOpInterfaceMethods<ArithFastMathInterface>,
+     AllTypesMatch<["operand", "sin", "cos"]>]> {
+  let summary = "sine and cosine of the specified value";
+  let description = [{
+    The `sincos` operation computes both the sine and cosine of a given value
+    simultaneously. It takes one operand of floating point type (i.e., scalar,
+    tensor or vector) and returns two results of the same type. This operation
+    can be more efficient than computing sine and cosine separately when both
+    values are needed.
+
+    Example:
+
+    ```mlir
+    // Scalar sine and cosine values.
+    %sin, %cos = math.sincos %input : f64
+    ```
+  }];
+
+  let arguments = (ins FloatLike:$operand,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let results = (outs FloatLike:$sin, FloatLike:$cos);
+
+  let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($operand) }];
+
+  let extraClassDeclaration = [{
+    std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp
 //===----------------------------------------------------------------------===//
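
A note on the custom assembly format above: because `AllTypesMatch` ties the operand and both results together, only the operand type is printed, and the fastmath flags are optional. A small sketch of the accepted syntax (value names are illustrative):

```mlir
%s, %c = math.sincos %x : f32
%s2, %c2 = math.sincos %x fastmath<afn> : f32
```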

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 98 additions & 1 deletion
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
                       LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                       LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                       LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
-                      LLVM::SqrtOp>();
+                      LLVM::SincosOp, LLVM::SqrtOp>();
 
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
   });
 }
 
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = adaptor.getOperand();
+    Type inputType = input.getType();
+    auto convertedInput = maybeExt(input, rewriter);
+    auto computeType = convertedInput.getType();
+
+    StringRef sincosFunc;
+    if (isa<Float32Type>(computeType)) {
+      const arith::FastMathFlags flag = op.getFastmath();
+      const bool useApprox =
+          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
+      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+    } else if (isa<Float64Type>(computeType)) {
+      sincosFunc = "__nv_sincos";
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported operand type for sincos");
+    }
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    Value sinPtr, cosPtr;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto *scope =
+          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      assert(scope && "Expected op to be inside automatic allocation scope");
+      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+      auto one = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+      sinPtr =
+          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+      cosPtr =
+          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+    }
+
+    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+                     op);
+
+    auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
+    auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+
+    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+                            maybeTrunc(cosResult, inputType, rewriter)});
+    return success();
+  }
+
+private:
+  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+    if (isa<Float16Type, BFloat16Type>(operand.getType()))
+      return rewriter.create<LLVM::FPExtOp>(
+          operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+    return operand;
+  }
+
+  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+    if (operand.getType() != type)
+      return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+    return operand;
+  }
+
+  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+                        StringRef funcName, Value input, Value sinPtr,
+                        Value cosPtr, Operation *op) const {
+    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    auto ptrType = sinPtr.getType();
+
+    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    auto funcOp =
+        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+    if (!funcOp) {
+      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+      assert(parentFunc && "expected there to be a parent function");
+      OpBuilder b(parentFunc);
+
+      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    }
+
+    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+    rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+  }
+};
+
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
                                RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
                                    "__nv_tan", "__nv_fast_tanf");
   populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                    "__nv_tanh");
+
+  // Custom pattern for sincos since it returns two values
+  patterns.add<SincosOpLowering>(converter, benefit);
 }
 
 void mlir::populateGpuToNVVMConversionPatterns(
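
Since `__nv_sincos*` returns its results through pointers, the pattern above allocates two stack slots in the enclosing automatic-allocation scope, calls the libdevice function, and loads the results back. Roughly, for an f32 operand the lowering produces IR of this shape (a hedged sketch based on the CHECK lines in the test below; value names are illustrative):

```mlir
llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)

// Inside the lowered function (the allocas are hoisted to the entry block):
%c1     = llvm.mlir.constant(1 : i32) : i32
%sinPtr = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
%cosPtr = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
llvm.call @__nv_sincosf(%x, %sinPtr, %cosPtr) : (f32, !llvm.ptr, !llvm.ptr) -> ()
%sin = llvm.load %sinPtr : !llvm.ptr -> f32
%cos = llvm.load %cosPtr : !llvm.ptr -> f32
```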

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 33 additions & 0 deletions
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
     LLVM::CountTrailingZerosOp>;
 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
 
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+    mlir::Location loc = op.getLoc();
+    mlir::Type operandType = adaptor.getOperand().getType();
+    mlir::Type llvmOperandType = typeConverter.convertType(operandType);
+    mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
+    mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
+    if (!llvmOperandType || !sinType || !cosType)
+      return failure();
+
+    ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+    auto structType = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+    auto sincosOp = rewriter.create<LLVM::SincosOp>(
+        loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+    auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+    auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+    rewriter.replaceOp(op, {sinValue, cosValue});
+    return success();
+  }
+};
+
 // A `expm1` is converted into `exp - 1`.
 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
                                        RoundEvenOpLowering,
                                        RoundOpLowering,
                                        RsqrtOpLowering,
+                                       SincosOpLowering,
                                        SinOpLowering,
                                        SinhOpLowering,
                                        ASinOpLowering,
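
On the plain LLVM path, the result struct returned by the intrinsic is unpacked with extractvalue, so the converted IR looks roughly like this (a sketch consistent with the math-to-llvm test below; value names are illustrative):

```mlir
%res = llvm.intr.sincos(%x) : (f32) -> !llvm.struct<(f32, f32)>
%sin = llvm.extractvalue %res[0] : !llvm.struct<(f32, f32)>
%cos = llvm.extractvalue %res[1] : !llvm.struct<(f32, f32)>
```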

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 10 additions & 0 deletions
@@ -284,6 +284,16 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// SinCosOp getShapeForUnroll
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
+  if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
+    return llvm::to_vector<4>(vt.getShape());
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 39 additions & 0 deletions
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
     func.return %result32, %result64 : f32, f64
   }
 }
+
+gpu.module @test_module_56 {
+  // CHECK: gpu.module @test_module_56
+
+  // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
+  // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)
+
+  // CHECK-LABEL: func @gpu_sincos
+  // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+  func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+    // CHECK-COUNT-6: llvm.alloca
+    // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+    // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK-COUNT-2: llvm.fptrunc
+    // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+    %sin16, %cos16 = math.sincos %arg_f16 : f16
+    %sin32, %cos32 = math.sincos %arg_f32 : f32
+    %sin64, %cos64 = math.sincos %arg_f64 : f64
+    func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+  }
+
+  // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)
+
+  // CHECK-LABEL: func @gpu_sincos_fastmath
+  // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+  func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+    // CHECK-COUNT-6: llvm.alloca
+    // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+    // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK-COUNT-2: llvm.fptrunc
+    // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+    %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16
+    %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32
+    %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64
+    func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+  }
+}

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {
 
 // -----
 
+// CHECK-LABEL: func @sincos
+// CHECK-SAME: [[ARG0:%.+]]: f32
+func.func @sincos(%arg0: f32) {
+  // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
+  %0:2 = math.sincos %arg0 : f32
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @inverse_trigonometrics
 // CHECK-SAME: [[ARG0:%.+]]: f32
 func.func @inverse_trigonometrics(%arg0: f32) {

mlir/test/Dialect/Math/ops.mlir

Lines changed: 12 additions & 0 deletions
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @sincos(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.sincos %[[F]] : f32
+  %0:2 = math.sincos %f : f32
+  // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
+  %1:2 = math.sincos %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
+  %2:2 = math.sincos %t : tensor<4x4x?xf32>
+  return
+}
+
 // CHECK-LABEL: func @erf(
 // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
