Skip to content

Commit 3a13167

Browse files
committed
[DirectX] Add WaveActiveOp builtin
- create int_dx_wave_active_op in IntrinsicsDirectX.td - add mapping to dxil op in DXIL.td - add scalarization to DirectXTargetTransformInfo.cpp - add tests of lowerings to dxil ops for both scalar and vector values
1 parent 23309d7 commit 3a13167

File tree

5 files changed

+102
-0
lines changed

5 files changed

+102
-0
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8383
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
8484
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8585
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
86+
def int_dx_wave_active_op : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i8_ty, llvm_i8_ty], [IntrConvergent, IntrNoMem]>;
8687
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8788
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8889
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,16 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
793793
let stages = [Stages<DXIL1_6, [all_stages]>];
794794
}
795795

796+
def WaveActiveOp : DXILOp<119, waveActiveOp> {
797+
let Doc = "returns the result of the operation across waves";
798+
let LLVMIntrinsic = int_dx_wave_active_op;
799+
let arguments = [OverloadTy, Int8Ty, Int8Ty];
800+
let result = OverloadTy;
801+
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>];
802+
let stages = [Stages<DXIL1_0, [all_stages]>];
803+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
804+
}
805+
796806
def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
797807
let Doc = "returns 1 for the first lane in the wave";
798808
let LLVMIntrinsic = int_dx_wave_is_first_lane;

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ using namespace llvm;
1818
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
1919
unsigned ScalarOpdIdx) {
2020
switch (ID) {
21+
case Intrinsic::dx_wave_active_op: {
22+
return ScalarOpdIdx == 1 || ScalarOpdIdx == 2;
23+
}
2124
default:
2225
return false;
2326
}
@@ -26,6 +29,7 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
2629
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
2730
Intrinsic::ID ID) const {
2831
switch (ID) {
32+
case Intrinsic::dx_wave_active_op:
2933
case Intrinsic::dx_frac:
3034
case Intrinsic::dx_rsqrt:
3135
return true;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op
4+
5+
define noundef <2 x half> @wave_active_op_v2half(<2 x half> noundef %expr) {
6+
entry:
7+
; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 0, i8 0)
8+
; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 0, i8 0)
9+
%ret = call <2 x half> @llvm.dx.wave.active.op.f16(<2 x half> %expr, i8 0, i8 0)
10+
ret <2 x half> %ret
11+
}
12+
13+
define noundef <3 x i32> @wave_active_op_v3i32(<3 x i32> noundef %expr) {
14+
entry:
15+
; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 1, i8 1)
16+
; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 1, i8 1)
17+
; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 1, i8 1)
18+
%ret = call <3 x i32> @llvm.dx.wave.active.op(<3 x i32> %expr, i8 1, i8 1)
19+
ret <3 x i32> %ret
20+
}
21+
22+
define noundef <4 x double> @wave_active_op_v4f64(<4 x double> noundef %expr) {
23+
entry:
24+
; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 2, i8 0)
25+
; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 2, i8 0)
26+
; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 2, i8 0)
27+
; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 2, i8 0)
28+
%ret = call <4 x double> @llvm.dx.wave.active.op(<4 x double> %expr, i8 2, i8 0)
29+
ret <4 x double> %ret
30+
}
31+
32+
declare <2 x half> @llvm.dx.wave.active.op.v2f16(<2 x half>, i8, i8)
33+
declare <3 x i32> @llvm.dx.wave.active.op.v3i32(<3 x i32>, i8, i8)
34+
declare <4 x double> @llvm.dx.wave.active.op.v4f64(<4 x double>, i8, i8)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op
4+
5+
define noundef half @wave_active_op_half(half noundef %expr) {
6+
entry:
7+
; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 0, i8 0)
8+
%ret = call half @llvm.dx.wave.active.op.f16(half %expr, i8 0, i8 0)
9+
ret half %ret
10+
}
11+
12+
define noundef float @wave_active_op_float(float noundef %expr) {
13+
entry:
14+
; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 1, i8 0)
15+
%ret = call float @llvm.dx.wave.active.op(float %expr, i8 1, i8 0)
16+
ret float %ret
17+
}
18+
19+
define noundef double @wave_active_op_double(double noundef %expr) {
20+
entry:
21+
; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 2, i8 0)
22+
%ret = call double @llvm.dx.wave.active.op(double %expr, i8 2, i8 0)
23+
ret double %ret
24+
}
25+
26+
define noundef i16 @wave_active_op_i16(i16 noundef %expr) {
27+
entry:
28+
; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 1, i8 0)
29+
%ret = call i16 @llvm.dx.wave.active.op.i16(i16 %expr, i8 1, i8 0)
30+
ret i16 %ret
31+
}
32+
33+
define noundef i32 @wave_active_op_i32(i32 noundef %expr) {
34+
entry:
35+
; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 2, i8 1)
36+
%ret = call i32 @llvm.dx.wave.active.op.i32(i32 %expr, i8 2, i8 1)
37+
ret i32 %ret
38+
}
39+
40+
define noundef i64 @wave_active_op_i64(i64 noundef %expr) {
41+
entry:
42+
; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0)
43+
%ret = call i64 @llvm.dx.wave.active.op.i64(i64 %expr, i8 3, i8 0)
44+
ret i64 %ret
45+
}
46+
47+
declare half @llvm.dx.wave.active.op.f16(half, i8, i8)
48+
declare float @llvm.dx.wave.active.op.f32(float, i8, i8)
49+
declare double @llvm.dx.wave.active.op.f64(double, i8, i8)
50+
51+
declare i16 @llvm.dx.wave.active.op.i16(i16, i8, i8)
52+
declare i32 @llvm.dx.wave.active.op.i32(i32, i8, i8)
53+
declare i64 @llvm.dx.wave.active.op.i64(i64, i8, i8)

0 commit comments

Comments
 (0)