Skip to content

Commit 484b208

Browse files
committed
review comments
- remove unneeded attributes from test functions
- provide context to parameters in CGBuiltin
- update intrinsic naming for clarity
1 parent 473150d commit 484b208

File tree

5 files changed

+48
-53
lines changed

5 files changed

+48
-53
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18848,8 +18848,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1884818848
std::string name =
1884918849
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
1885018850
ArrayRef{OpExpr->getType()}, &CGM.getModule());
18851-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
18852-
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
18851+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {},
18852+
/*Local=*/false,
18853+
/*AssumeConvergent=*/true),
18854+
ArrayRef{OpExpr, OpIndex}, "hlsl.waveReadLaneAt");
1885318855
}
1885418856
case Builtin::BI__builtin_hlsl_elementwise_sign: {
1885518857
Value *Op0 = EmitScalarExpr(E->getArg(0));

clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
// CHECK-LABEL: test_int
1111
int test_int(int expr, uint idx) {
1212
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13-
1413
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok]]) ]
1514
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
16-
1715
// CHECK: ret [[TY]] %[[RET]]
1816
return WaveReadLaneAt(expr, idx);
1917
}
@@ -26,10 +24,8 @@ int test_int(int expr, uint idx) {
2624
// CHECK-LABEL: test_floatv4
2725
float4 test_floatv4(float4 expr, uint idx) {
2826
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
29-
3027
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
3128
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
32-
3329
// CHECK: ret [[TY1]] %[[RET1]]
3430
return WaveReadLaneAt(expr, idx);
3531
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
244244
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
245245
MachineInstr &I) const;
246246

247+
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
248+
MachineInstr &I) const;
249+
247250
bool selectStep(Register ResVReg, const SPIRVType *ResType,
248251
MachineInstr &I) const;
249252

@@ -1853,6 +1856,26 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
18531856
return Result;
18541857
}
18551858

1859+
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
1860+
const SPIRVType *ResType,
1861+
MachineInstr &I) const {
1862+
assert(I.getNumOperands() == 4);
1863+
assert(I.getOperand(2).isReg());
1864+
assert(I.getOperand(3).isReg());
1865+
MachineBasicBlock &BB = *I.getParent();
1866+
1867+
// Lowers to OpGroupNonUniformShuffle, forwarding operands 2 and 3 of I unchanged.
1868+
// IntTy is the 32-bit integer type of the execution-scope operand; its value 3 denotes a cross-lane interaction equivalent to a SPIR-V subgroup.
1869+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1870+
return BuildMI(BB, I, I.getDebugLoc(),
1871+
TII.get(SPIRV::OpGroupNonUniformShuffle))
1872+
.addDef(ResVReg)
1873+
.addUse(GR.getSPIRVTypeID(ResType))
1874+
.addUse(I.getOperand(2).getReg())
1875+
.addUse(I.getOperand(3).getReg())
1876+
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
1877+
}
1878+
18561879
bool SPIRVInstructionSelector::selectStep(Register ResVReg,
18571880
const SPIRVType *ResType,
18581881
MachineInstr &I) const {
@@ -2653,22 +2676,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
26532676
.addUse(GR.getSPIRVTypeID(ResType))
26542677
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
26552678
}
2656-
case Intrinsic::spv_wave_read_lane_at: {
2657-
assert(I.getNumOperands() == 4);
2658-
assert(I.getOperand(2).isReg());
2659-
assert(I.getOperand(3).isReg());
2660-
2661-
// IntTy is used to define the execution scope, set to 3 to denote a
2662-
// cross-lane interaction equivalent to a SPIR-V subgroup.
2663-
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
2664-
return BuildMI(BB, I, I.getDebugLoc(),
2665-
TII.get(SPIRV::OpGroupNonUniformShuffle))
2666-
.addDef(ResVReg)
2667-
.addUse(GR.getSPIRVTypeID(ResType))
2668-
.addUse(I.getOperand(2).getReg())
2669-
.addUse(I.getOperand(3).getReg())
2670-
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
2671-
}
2679+
case Intrinsic::spv_wave_read_lane_at:
2680+
return selectWaveReadLaneAt(ResVReg, ResType, I);
26722681
case Intrinsic::spv_step:
26732682
return selectStep(ResVReg, ResType, I);
26742683
// Discard intrinsics which we do not expect to actually represent code after

llvm/test/CodeGen/DirectX/WaveReadLaneAt.ll

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,52 @@
22

33
; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op
44

5-
define noundef half @wave_rla_half(half noundef %expr, i32 noundef %idx) #0 {
5+
define noundef half @wave_rla_half(half noundef %expr, i32 noundef %idx) {
66
entry:
77
; CHECK: call half @dx.op.waveReadLaneAt.f16(i32 117, half %expr, i32 %idx)
88
%ret = call half @llvm.dx.wave.read.lane.at.f16(half %expr, i32 %idx)
99
ret half %ret
1010
}
1111

12-
define noundef float @wave_rla_float(float noundef %expr, i32 noundef %idx) #0 {
12+
define noundef float @wave_rla_float(float noundef %expr, i32 noundef %idx) {
1313
entry:
1414
; CHECK: call float @dx.op.waveReadLaneAt.f32(i32 117, float %expr, i32 %idx)
1515
%ret = call float @llvm.dx.wave.read.lane.at(float %expr, i32 %idx)
1616
ret float %ret
1717
}
1818

19-
define noundef double @wave_rla_double(double noundef %expr, i32 noundef %idx) #0 {
19+
define noundef double @wave_rla_double(double noundef %expr, i32 noundef %idx) {
2020
entry:
2121
; CHECK: call double @dx.op.waveReadLaneAt.f64(i32 117, double %expr, i32 %idx)
2222
%ret = call double @llvm.dx.wave.read.lane.at(double %expr, i32 %idx)
2323
ret double %ret
2424
}
2525

26-
define noundef i1 @wave_rla_i1(i1 noundef %expr, i32 noundef %idx) #0 {
26+
define noundef i1 @wave_rla_i1(i1 noundef %expr, i32 noundef %idx) {
2727
entry:
2828
; CHECK: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1 %expr, i32 %idx)
2929
%ret = call i1 @llvm.dx.wave.read.lane.at.i1(i1 %expr, i32 %idx)
3030
ret i1 %ret
3131
}
3232

33-
define noundef i16 @wave_rla_i16(i16 noundef %expr, i32 noundef %idx) #0 {
33+
define noundef i16 @wave_rla_i16(i16 noundef %expr, i32 noundef %idx) {
3434
entry:
3535
; CHECK: call i16 @dx.op.waveReadLaneAt.i16(i32 117, i16 %expr, i32 %idx)
3636
%ret = call i16 @llvm.dx.wave.read.lane.at.i16(i16 %expr, i32 %idx)
3737
ret i16 %ret
3838
}
3939

40-
define noundef i32 @wave_rla_i32(i32 noundef %expr, i32 noundef %idx) #0 {
40+
define noundef i32 @wave_rla_i32(i32 noundef %expr, i32 noundef %idx) {
4141
entry:
4242
; CHECK: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32 %expr, i32 %idx)
4343
%ret = call i32 @llvm.dx.wave.read.lane.at.i32(i32 %expr, i32 %idx)
4444
ret i32 %ret
4545
}
4646

47-
declare half @llvm.dx.wave.read.lane.at.f16(half, i32) #1
48-
declare float @llvm.dx.wave.read.lane.at.f32(float, i32) #1
49-
declare double @llvm.dx.wave.read.lane.at.f64(double, i32) #1
47+
declare half @llvm.dx.wave.read.lane.at.f16(half, i32)
48+
declare float @llvm.dx.wave.read.lane.at.f32(float, i32)
49+
declare double @llvm.dx.wave.read.lane.at.f64(double, i32)
5050

51-
declare i1 @llvm.dx.wave.read.lane.at.i1(i1, i32) #1
52-
declare i16 @llvm.dx.wave.read.lane.at.i16(i16, i32) #1
53-
declare i32 @llvm.dx.wave.read.lane.at.i32(i32, i32) #1
54-
55-
attributes #0 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
56-
attributes #1 = { nocallback nofree nosync nounwind willreturn }
51+
declare i1 @llvm.dx.wave.read.lane.at.i1(i1, i32)
52+
declare i16 @llvm.dx.wave.read.lane.at.i16(i16, i32)
53+
declare i32 @llvm.dx.wave.read.lane.at.i32(i32, i32)

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,17 @@
1313
; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
1414
; CHECK-DAG: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_bool]]
1515

16-
define spir_func void @test_1(float %fexpr, i32 %iexpr, <4 x i1> %vbexpr, i32 %idx) #0 {
16+
define spir_func void @test_1(float %fexpr, i32 %iexpr, <4 x i1> %vbexpr, i32 %idx) {
1717
entry:
18-
%0 = call token @llvm.experimental.convergence.entry()
1918
; CHECK: %[[#fret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#fexpr]] %[[#idx]] %[[#scope]]
20-
%1 = call float @llvm.spv.wave.read.lane.at.f32(float %fexpr, i32 %idx) [ "convergencectrl"(token %0) ]
19+
%0 = call float @llvm.spv.wave.read.lane.at.f32(float %fexpr, i32 %idx)
2120
; CHECK: %[[#iret:]] = OpGroupNonUniformShuffle %[[#uint]] %[[#iexpr]] %[[#idx]] %[[#scope]]
22-
%2 = call i32 @llvm.spv.wave.read.lane.at.i32(i32 %iexpr, i32 %idx) [ "convergencectrl"(token %0) ]
21+
%1 = call i32 @llvm.spv.wave.read.lane.at.i32(i32 %iexpr, i32 %idx)
2322
; CHECK: %[[#vbret:]] = OpGroupNonUniformShuffle %[[#v4_bool]] %[[#vbexpr]] %[[#idx]] %[[#scope]]
24-
%3 = call <4 x i1> @llvm.spv.wave.read.lane.at.v4i1(<4 x i1> %vbexpr, i32 %idx) [ "convergencectrl"(token %0) ]
23+
%2 = call <4 x i1> @llvm.spv.wave.read.lane.at.v4i1(<4 x i1> %vbexpr, i32 %idx)
2524
ret void
2625
}
2726

28-
declare float @__hlsl_wave_read_lane_at.f32(float, i32) #1
29-
declare i32 @__hlsl_wave_read_lane_at.i32(i32, i32) #1
30-
declare <4 x i1> @__hlsl_wave_read_lane_at.v4i1(<4 x i1>, i32) #1
31-
32-
attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
33-
attributes #1 = { convergent }
34-
35-
!llvm.module.flags = !{!0, !1}
36-
37-
!0 = !{i32 1, !"wchar_size", i32 4}
38-
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
27+
declare float @llvm.spv.wave.read.lane.at.f32(float, i32)
28+
declare i32 @llvm.spv.wave.read.lane.at.i32(i32, i32)
29+
declare <4 x i1> @llvm.spv.wave.read.lane.at.v4i1(<4 x i1>, i32)

0 commit comments

Comments
 (0)