Skip to content

Commit 7305ed7

Browse files
authored
[ROCDL] added math instructions to the ROCDL dialect (#169672)
Exposed llvm amdgcn math intrinsic calls through ROCDL
1 parent c6e23ab commit 7305ed7

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,33 @@ def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res",
19131913
}];
19141914
}
19151915

1916+
//===----------------------------------------------------------------------===//
1917+
// Math operations
1918+
//===----------------------------------------------------------------------===//
1919+
1920+
// Base class for the ROCDL math intrinsic ops: one floating-point operand
// (`$arg`) and one floating-point result (`$res`), both `LLVM_AnyFloat`.
//
// `mnemonic` is forwarded to ROCDL_IntrOp together with `[0]`, `[]` and `1`
// (presumably the overloaded-result indices, the overloaded-operand indices
// and the result count -- confirm against the ROCDL_IntrOp definition).
// `traits` defaults to `[Pure]`.
//
// `AllTypesMatch<["arg", "res"]>` is always appended to the caller's traits:
// only the result type is marked as overloaded, so without this constraint an
// op such as `rocdl.tanh %x f32 -> f16` would pass verification even though
// operand and result types disagree. ROCDL_FMed3Op uses the same constraint.
//
// The custom assembly prints the operand, its type, the attribute dictionary
// and then `->` followed by the result type, e.g. `rocdl.tanh %a f32 -> f32`.
class ROCDL_Math_IntrOp<string mnemonic, list<Trait> traits = [Pure]> :
1921+
ROCDL_IntrOp<mnemonic, [0], [], !listconcat(traits, [AllTypesMatch<["arg", "res"]>]), 1>,
1922+
Arguments<(ins LLVM_AnyFloat:$arg)> {
1923+
let results = (outs LLVM_AnyFloat:$res);
1924+
let description = [{
1925+
Note: In the general case, prefer the conventional `arith`, `math`, or `llvm` ops over this.
1926+
Use this ROCDL-specific operation only when you fully understand its implication and
1927+
when it is strictly necessary. This op is usually chosen when a small loss in precision is
1928+
acceptable in exchange for higher execution speed.
1929+
}];
1930+
let assemblyFormat =
1931+
"$arg qualified(type($arg)) attr-dict `->` qualified(type($res))";
1932+
}
1933+
1934+
// Concrete math ops. Each one instantiates ROCDL_Math_IntrOp with its
// mnemonic and the default traits, giving `rocdl.<mnemonic>`.
//
// NOTE(review): the dialect and translation tests added alongside these defs
// cover tanh, sin, cos, rcp, exp2, log and sqrt, but have no case for `exp`.
// Confirm that `llvm.amdgcn.exp` really is a unary float intrinsic: upstream
// LLVM appears to use that name for the export intrinsic, in which case
// `rocdl.exp` would not lower to an exponential -- TODO verify and either add
// a test or drop the op.
// NOTE(review): presumably `log` and `exp2` are the base-2 hardware
// instructions; verify against the AMDGPU intrinsic definitions.
def ROCDLTanh : ROCDL_Math_IntrOp<"tanh">;
1935+
def ROCDLSin : ROCDL_Math_IntrOp<"sin">;
1936+
def ROCDLCos : ROCDL_Math_IntrOp<"cos">;
1937+
def ROCDLRcp : ROCDL_Math_IntrOp<"rcp">;
1938+
def ROCDLExp : ROCDL_Math_IntrOp<"exp">;
1939+
def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
1940+
def ROCDLLog : ROCDL_Math_IntrOp<"log">;
1941+
def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
1942+
19161943
//===----------------------------------------------------------------------===//
19171944
// ROCDL target attribute.
19181945
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4
4949
llvm.return %0 : vector<4xf16>
5050
}
5151

52+
// Dialect-level test for the ROCDL math ops: every op below is written once
// per operand type (f32, f16, bf16) and the expected output repeats the same
// `<op> <operand> <type> -> <type>` form. The tool invocation for this file is
// not visible here; presumably it parses and re-prints the IR.
// NOTE(review): rocdl.exp is defined in ROCDLOps.td but has no case in this
// function; tanh, sin, cos, rcp, exp2, log and sqrt are covered.
func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
53+
// CHECK-LABEL: rocdl.math.ops
54+
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32
55+
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16
56+
// CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16
57+
%tanh0 = rocdl.tanh %a f32 -> f32
58+
%tanh1 = rocdl.tanh %b f16 -> f16
59+
%tanh2 = rocdl.tanh %c bf16 -> bf16
60+
61+
// CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32
62+
// CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16
63+
// CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16
64+
%sin0 = rocdl.sin %a f32 -> f32
65+
%sin1 = rocdl.sin %b f16 -> f16
66+
%sin2 = rocdl.sin %c bf16 -> bf16
67+
68+
// CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32
69+
// CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16
70+
// CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16
71+
%cos0 = rocdl.cos %a f32 -> f32
72+
%cos1 = rocdl.cos %b f16 -> f16
73+
%cos2 = rocdl.cos %c bf16 -> bf16
74+
75+
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32
76+
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16
77+
// CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16
78+
%rcp0 = rocdl.rcp %a f32 -> f32
79+
%rcp1 = rocdl.rcp %b f16 -> f16
80+
%rcp2 = rocdl.rcp %c bf16 -> bf16
81+
82+
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32
83+
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16
84+
// CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16
85+
%exp2_0 = rocdl.exp2 %a f32 -> f32
86+
%exp2_1 = rocdl.exp2 %b f16 -> f16
87+
%exp2_2 = rocdl.exp2 %c bf16 -> bf16
88+
89+
// CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32
90+
// CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16
91+
// CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16
92+
%log0 = rocdl.log %a f32 -> f32
93+
%log1 = rocdl.log %b f16 -> f16
94+
%log2 = rocdl.log %c bf16 -> bf16
95+
96+
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32
97+
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16
98+
// CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16
99+
%sqrt0 = rocdl.sqrt %a f32 -> f32
100+
%sqrt1 = rocdl.sqrt %b f16 -> f16
101+
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
102+
llvm.return
103+
}
104+
52105
func.func @rocdl.barrier() {
53106
// CHECK: rocdl.barrier
54107
rocdl.barrier

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,59 @@ llvm.func @kernel_func_workgroups()
6161
llvm.return
6262
}
6363

64+
// Translation test for the ROCDL math ops: each op is applied to an f32, f16
// and bf16 operand, and the expected LLVM IR is a call to the matching
// `llvm.amdgcn.<op>.<f32|f16|bf16>` intrinsic taking and returning
// float / half / bfloat respectively.
// NOTE(review): rocdl.exp is defined in ROCDLOps.td but is not exercised
// here; tanh, sin, cos, rcp, exp2, log and sqrt are covered. Confirm what
// `llvm.amdgcn.exp` lowers to before adding a case for it.
llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
65+
// CHECK-LABEL: kernel_math_ops
66+
// CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}})
67+
// CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}})
68+
// CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}})
69+
%tanh0 = rocdl.tanh %a f32 -> f32
70+
%tanh1 = rocdl.tanh %b f16 -> f16
71+
%tanh2 = rocdl.tanh %c bf16 -> bf16
72+
73+
// CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}})
74+
// CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}})
75+
// CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}})
76+
%sin0 = rocdl.sin %a f32 -> f32
77+
%sin1 = rocdl.sin %b f16 -> f16
78+
%sin2 = rocdl.sin %c bf16 -> bf16
79+
80+
// CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}})
81+
// CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}})
82+
// CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}})
83+
%cos0 = rocdl.cos %a f32 -> f32
84+
%cos1 = rocdl.cos %b f16 -> f16
85+
%cos2 = rocdl.cos %c bf16 -> bf16
86+
87+
// CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}})
88+
// CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}})
89+
// CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}})
90+
%rcp0 = rocdl.rcp %a f32 -> f32
91+
%rcp1 = rocdl.rcp %b f16 -> f16
92+
%rcp2 = rocdl.rcp %c bf16 -> bf16
93+
94+
// CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}})
95+
// CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}})
96+
// CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}})
97+
%exp2_0 = rocdl.exp2 %a f32 -> f32
98+
%exp2_1 = rocdl.exp2 %b f16 -> f16
99+
%exp2_2 = rocdl.exp2 %c bf16 -> bf16
100+
101+
// CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}})
102+
// CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}})
103+
// CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}})
104+
%log0 = rocdl.log %a f32 -> f32
105+
%log1 = rocdl.log %b f16 -> f16
106+
%log2 = rocdl.log %c bf16 -> bf16
107+
108+
// CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}})
109+
// CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}})
110+
// CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}})
111+
%sqrt0 = rocdl.sqrt %a f32 -> f32
112+
%sqrt1 = rocdl.sqrt %b f16 -> f16
113+
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
114+
llvm.return
115+
}
116+
64117
llvm.func @known_block_sizes()
65118
attributes {rocdl.kernel,
66119
rocdl.flat_work_group_size = "128,128",

0 commit comments

Comments
 (0)