@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
436
436
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
437
437
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
438
438
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
439
- LLVM::SqrtOp>();
439
+ LLVM::SincosOp, LLVM:: SqrtOp>();
440
440
441
441
// TODO: Remove once we support replacing non-root ops.
442
442
target.addLegalOp <gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
466
466
});
467
467
}
468
468
469
+ struct SincosOpLowering : public ConvertOpToLLVMPattern <math::SincosOp> {
470
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
471
+
472
+ LogicalResult
473
+ matchAndRewrite (math::SincosOp op, OpAdaptor adaptor,
474
+ ConversionPatternRewriter &rewriter) const override {
475
+ Location loc = op.getLoc ();
476
+ Value input = adaptor.getOperand ();
477
+ Type inputType = input.getType ();
478
+ auto convertedInput = maybeExt (input, rewriter);
479
+ auto computeType = convertedInput.getType ();
480
+
481
+ StringRef sincosFunc;
482
+ if (isa<Float32Type>(computeType)) {
483
+ const arith::FastMathFlags flag = op.getFastmath ();
484
+ const bool useApprox =
485
+ mlir::arith::bitEnumContainsAny (flag, arith::FastMathFlags::afn);
486
+ sincosFunc = useApprox ? " __nv_fast_sincosf" : " __nv_sincosf" ;
487
+ } else if (isa<Float64Type>(computeType)) {
488
+ sincosFunc = " __nv_sincos" ;
489
+ } else {
490
+ return rewriter.notifyMatchFailure (op,
491
+ " unsupported operand type for sincos" );
492
+ }
493
+
494
+ auto ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
495
+
496
+ Value sinPtr, cosPtr;
497
+ {
498
+ OpBuilder::InsertionGuard guard (rewriter);
499
+ auto *scope =
500
+ op->getParentWithTrait <mlir::OpTrait::AutomaticAllocationScope>();
501
+ assert (scope && " Expected op to be inside automatic allocation scope" );
502
+ rewriter.setInsertionPointToStart (&scope->getRegion (0 ).front ());
503
+ auto one = rewriter.create <LLVM::ConstantOp>(
504
+ loc, rewriter.getI32Type (), rewriter.getI32IntegerAttr (1 ));
505
+ sinPtr =
506
+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
507
+ cosPtr =
508
+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, computeType, one, 0 );
509
+ }
510
+
511
+ createSincosCall (rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
512
+ op);
513
+
514
+ auto sinResult = rewriter.create <LLVM::LoadOp>(loc, computeType, sinPtr);
515
+ auto cosResult = rewriter.create <LLVM::LoadOp>(loc, computeType, cosPtr);
516
+
517
+ rewriter.replaceOp (op, {maybeTrunc (sinResult, inputType, rewriter),
518
+ maybeTrunc (cosResult, inputType, rewriter)});
519
+ return success ();
520
+ }
521
+
522
+ private:
523
+ Value maybeExt (Value operand, PatternRewriter &rewriter) const {
524
+ if (isa<Float16Type, BFloat16Type>(operand.getType ()))
525
+ return rewriter.create <LLVM::FPExtOp>(
526
+ operand.getLoc (), Float32Type::get (rewriter.getContext ()), operand);
527
+ return operand;
528
+ }
529
+
530
+ Value maybeTrunc (Value operand, Type type, PatternRewriter &rewriter) const {
531
+ if (operand.getType () != type)
532
+ return rewriter.create <LLVM::FPTruncOp>(operand.getLoc (), type, operand);
533
+ return operand;
534
+ }
535
+
536
+ void createSincosCall (ConversionPatternRewriter &rewriter, Location loc,
537
+ StringRef funcName, Value input, Value sinPtr,
538
+ Value cosPtr, Operation *op) const {
539
+ auto voidType = LLVM::LLVMVoidType::get (rewriter.getContext ());
540
+ auto ptrType = sinPtr.getType ();
541
+
542
+ SmallVector<Type> operandTypes = {input.getType (), ptrType, ptrType};
543
+ auto funcType = LLVM::LLVMFunctionType::get (voidType, operandTypes);
544
+
545
+ auto funcAttr = StringAttr::get (op->getContext (), funcName);
546
+ auto funcOp =
547
+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
548
+
549
+ if (!funcOp) {
550
+ auto parentFunc = op->getParentOfType <FunctionOpInterface>();
551
+ assert (parentFunc && " expected there to be a parent function" );
552
+ OpBuilder b (parentFunc);
553
+
554
+ auto globalloc = loc->findInstanceOfOrUnknown <FileLineColLoc>();
555
+ funcOp = LLVM::LLVMFuncOp::create (b, globalloc, funcName, funcType);
556
+ }
557
+
558
+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
559
+ rewriter.create <LLVM::CallOp>(loc, funcOp, callOperands);
560
+ }
561
+ };
562
+
469
563
template <typename OpTy>
470
564
static void populateOpPatterns (const LLVMTypeConverter &converter,
471
565
RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
589
683
" __nv_tan" , " __nv_fast_tanf" );
590
684
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, " __nv_tanhf" ,
591
685
" __nv_tanh" );
686
+
687
+ // Custom pattern for sincos since it returns two values
688
+ patterns.add <SincosOpLowering>(converter, benefit);
592
689
}
593
690
594
691
void mlir::populateGpuToNVVMConversionPatterns (
0 commit comments