Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18827,9 +18827,21 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
{}, false, true));
// Since we don't define a SPIR-V intrinsic for the SPIR-V built-in from
// SPIRVBuiltins.td, manually get the matching name for the DirectX
// intrinsic and the demangled builtin name
switch (CGM.getTarget().getTriple().getArch()) {
case llvm::Triple::dxil:
return EmitRuntimeCall(Intrinsic::getDeclaration(
&CGM.getModule(), Intrinsic::dx_waveGetLaneIndex));
case llvm::Triple::spirv:
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false),
"__hlsl_wave_get_lane_index", {}, false, true));
default:
llvm_unreachable(
"Intrinsic WaveGetLaneIndex not supported by target architecture");
}
}
case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,11 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
if (SemaRef.checkArgCount(TheCall, 0))
return true;
break;
}
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
Expand Down
20 changes: 14 additions & 6 deletions clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,CHECK-SPIRV
// RUN: %clang_cc1 -finclude-default-header \
// RUN: -triple dxil-pc-shadermodel6.3-library %s \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,CHECK-DXIL

// CHECK: define spir_func noundef i32 @_Z6test_1v() [[A0:#[0-9]+]] {
// CHECK: %[[CI:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CI]]) ]
uint test_1() {
// CHECK-SPIRV: define spir_func noundef i32 @{{.*test_1.*}}() [[A0:#[0-9]+]] {
// CHECK-DXIL: define noundef i32 @{{.*test_1.*}}() [[A0:#[0-9]+]] {
// CHECK-SPIRV: %[[CI:[0-9]+]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CI]]) ]
// CHECK-DXIL: call i32 @llvm.dx.waveGetLaneIndex()
int test_1() {
return WaveGetLaneIndex();
}

// CHECK: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]
// CHECK-SPIRV: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]
// CHECK-DXIL: declare i32 @llvm.dx.waveGetLaneIndex() [[A1:#[0-9]+]]

// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
// CHECK-DAG: attributes [[A1]] = { {{.*}}convergent{{.*}} }
6 changes: 6 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveGetLaneIndex-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected

int test_too_many_arg(int x) {
return __builtin_hlsl_wave_get_lane_index(x);
// expected-error@-1 {{too many arguments to function call, expected 0, have 1}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_waveGetLaneIndex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,12 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let Doc = "returns the index of the current lane in the wave";
let LLVMIntrinsic = int_dx_waveGetLaneIndex;
let arguments = [];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveGetLaneIndex.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define void @main() {
entry:
; CHECK: call i32 @dx.op.waveGetLaneIndex(i32 111)
%0 = call i32 @llvm.dx.waveGetLaneIndex()
ret void
}

declare i32 @llvm.dx.waveGetLaneIndex()
Loading