Skip to content

Commit 5ffd73b

Browse files
[MLIR][NVVM] Add log2.approx.f32 intrinsic with FTZ variant
1 parent e9c5740 commit 5ffd73b

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,30 +453,33 @@ def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [Pure], 1> {
453453
// NVVM lg2 approximate op
454454
//===----------------------------------------------------------------------===//
455455

456-
def NVVM_Lg2ApproxF32Op : NVVM_Op<"lg2.approx.f", [Pure]> {
456+
def NVVM_Log2ApproxF32Op : NVVM_Op<"log2.approx.f", [Pure]> {
457457

458458
let summary = "Compute approximate base-2 log (lg2.approx.f32)";
459459

460460
let description = [{
461461
Compute the approximate base-2 logarithm of the input value.
462462
If 'ftz' is true, subnormal inputs and results are flushed to sign-preserving zero.
463+
464+
[PTX ISA Reference for lg2.approx]
465+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2)
466+
463467
}];
464468

465469
let arguments = (ins
466-
F32:$arg,
467-
OptionalAttr<BoolAttr>:$ftz);
470+
F32:$input,
471+
DefaultValuedAttr<BoolAttr, "false">:$ftz);
468472

469473
let results = (outs F32:$res);
470474

471-
let assemblyFormat = "$arg attr-dict `:` type($res)";
475+
let assemblyFormat = "$input attr-dict `:` type($res)";
472476

473477
string llvmBuilder = [{
474-
bool ftz = $ftz && *$ftz;
475478
llvm::Intrinsic::ID intrinsicID =
476-
ftz ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
479+
$ftz ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
477480
: llvm::Intrinsic::nvvm_lg2_approx_f;
478481

479-
$res = createIntrinsicCall(builder, intrinsicID, {$arg});
482+
$res = createIntrinsicCall(builder, intrinsicID, {$input});
480483
}];
481484
}
482485

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,6 @@ func.func @nvvm_rcp(%arg0: f32) -> f32 {
3636
llvm.return %0 : f32
3737
}
3838

39-
// CHECK-LABEL: @nvvm_lg2_approx_f
40-
func.func @nvvm_lg2_approx_f(%arg0: f32) -> f32 {
41-
// CHECK: nvvm.lg2.approx.f %arg0 : f32
42-
%0 = nvvm.lg2.approx.f %arg0 : f32
43-
llvm.return %0 : f32
44-
}
45-
46-
// CHECK-LABEL: @nvvm_lg2_approx_ftz_f
47-
func.func @nvvm_lg2_approx_ftz_f(%arg0: f32) -> f32 {
48-
// CHECK: nvvm.lg2.approx.f %arg0 {ftz = true} : f32
49-
%0 = nvvm.lg2.approx.f %arg0 {ftz = true} : f32
50-
llvm.return %0 : f32
51-
}
52-
5339
// CHECK-LABEL: @llvm_nvvm_barrier0
5440
func.func @llvm_nvvm_barrier0() {
5541
// CHECK: nvvm.barrier0
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define float @nvvm_log2_approx_f(float %0)
4+
// CHECK: call float @llvm.nvvm.lg2.approx.f(float %0)
5+
llvm.func @nvvm_log2_approx_f(%arg0: f32) -> f32 {
6+
%0 = nvvm.log2.approx.f %arg0 : f32
7+
llvm.return %0 : f32
8+
}
9+
10+
// CHECK-LABEL: define float @nvvm_log2_approx_ftz_f(float %0)
11+
// CHECK: call float @llvm.nvvm.lg2.approx.ftz.f(float %0)
12+
llvm.func @nvvm_log2_approx_ftz_f(%arg0: f32) -> f32 {
13+
%0 = nvvm.log2.approx.f %arg0 {ftz = true} : f32
14+
llvm.return %0 : f32
15+
}

0 commit comments

Comments
 (0)