Skip to content

Commit 922ba57

Browse files
authored
Allow noinline functions to be called with correct argument types. (#3963)
This PR fixes issues with noinline functions so that they can be called with the correct argument types. The changes focus on simplifying function signatures and removing unnecessary attributes and metadata.

Signed-off-by: Tiotto, Ettore <[email protected]>
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent fb2fff8 commit 922ba57

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
382382
torch.testing.assert_close(expect, actual)
383383

384384

385-
@triton.jit(noinline=False)
385+
@triton.jit(noinline=True)
386386
def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
387387
in_desc = tl.make_tensor_descriptor(
388388
in_ptr,

test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --tritonintelgpu-rewrite-stack-ptr | FileCheck %s
22

3-
module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
4-
// CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
5-
// CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: !llvm.ptr<1>)
6-
tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
3+
module attributes {"ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32} {
4+
// CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
5+
// CHECK: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, [[GLOBAL_PTR:%.*]]: !llvm.ptr<1>, [[PROFILE_PTR:%.*]]: !llvm.ptr<1>)
6+
tt.func public @kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
77
%0 = tt.load %arg0 : !tt.ptr<f32>
88
%1 = tt.load %arg1 : !tt.ptr<f32>
99
// CHECK: [[LOAD0:%.*]] = llvm.extractelement {{.*}}[{{.*}}]
@@ -13,8 +13,8 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
1313
tt.call @noinline_simple_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) : (f32, f32, !tt.ptr<f32>) -> ()
1414
tt.return
1515
}
16-
// CHECK: llvm.func internal spir_funccc @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>, %arg5: !llvm.ptr<1>)
17-
tt.func private @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
16+
// CHECK: llvm.func internal spir_funccc @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>, %arg5: !llvm.ptr<1>)
17+
tt.func private @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
1818
%0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
1919
tt.store %arg2, %0 : !tt.ptr<f32>
2020
tt.return
@@ -27,21 +27,21 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
2727
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
2828
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
2929
#smem = #ttg.shared_memory
30-
module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
31-
// CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
32-
// CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: !llvm.ptr<1>, %arg5: !llvm.ptr<3>)
33-
tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
30+
module attributes {"ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, "ttg.threads-per-warp" = 16 : i32} {
31+
// CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
32+
// CHECK: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, [[GLOBAL_PTR:%.*]]: !llvm.ptr<1>, [[PROFILE_PTR:%.*]]: !llvm.ptr<1>, [[SHARED_MEM_PTR:%.*]]: !llvm.ptr<3>)
33+
tt.func public @kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) {
3434
%0 = tt.load %arg0 : !tt.ptr<f32>
3535
%1 = tt.load %arg1 : !tt.ptr<f32>
3636
// CHECK: [[LOAD0:%.*]] = llvm.extractelement {{.*}}[{{.*}}]
3737
// CHECK: [[LOAD1:%.*]] = llvm.extractelement {{.*}}[{{.*}}]
38-
// CHECK: llvm.call spir_funccc @noinline_shared_fn__fp32_fp32_Pfp32__([[LOAD0]], [[LOAD1]], %arg2, %arg5, %arg3, %arg4)
39-
tt.call @noinline_shared_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
38+
// CHECK: llvm.call spir_funccc @noinline_shared_fn([[LOAD0]], [[LOAD1]], %arg2, [[SHARED_MEM_PTR]], [[GLOBAL_PTR]], [[PROFILE_PTR]])
39+
tt.call @noinline_shared_fn(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
4040
tt.return
4141
}
42-
// CHECK: llvm.func internal spir_funccc @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>, %arg5: !llvm.ptr<1>)
42+
// CHECK: llvm.func internal spir_funccc @noinline_shared_fn(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>, %arg5: !llvm.ptr<1>)
4343
// CHECK: llvm.getelementptr %arg3[{{.*}}]
44-
tt.func private @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
44+
tt.func private @noinline_shared_fn(%arg0: f32, %arg1: f32, %arg2: !tt.ptr<f32>) attributes {noinline = true} {
4545
%cst = arith.constant dense<16> : tensor<16x1xi32, #blocked>
4646
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
4747
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked>

third_party/intel/backend/driver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,10 @@ def format_of(ty):
609609
cgh.set_arg(index, *static_cast<const T *>(value));
610610
}}
611611
612-
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr, void* global_scratch, void* profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
612+
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ,
613+
int num_warps, int threads_per_warp, int shared_memory,
614+
sycl::queue& stream, sycl::kernel& kernel_ptr,
615+
void* global_scratch, void* profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
613616
614617
std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
615618
{ 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER.inject_pytorch_dep else "" }

0 commit comments

Comments
 (0)