Commit 2426ba7

Fix warp num lookup segfault for empty modules (#4322)
I recently noticed a subtle problem with pattern application in the conversion-to-LLVM pass. The symptom: two modules that each lower to LLVM correctly on their own fail when placed in a single test case, because the matcher cannot find legalization patterns. While trying to create a minimal reproducer, I found that simply adding an empty module makes the conversion segfault. This small fix also helps with the original problem, although I'm not yet sure why.
1 parent d7a6c9c commit 2426ba7

2 files changed: +8 -3 lines changed

test/Conversion/intel/tritonintelgpu_to_llvm.mlir

Lines changed: 7 additions & 2 deletions
@@ -1,11 +1,11 @@
-// RUN: triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s
+// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
 module attributes { "ttg.threads-per-warp" = 16 : i32, "ttg.num-warps" = 4 : i32 } {
 // As the assert message is shared, a single instance is emitted.
 
 // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_("unknown\00") {addr_space = 1 : i32}
-// CHECK-DAG: llvm.mlir.global internal constant @assertFile_("{{.*}}tritonintelgpu_to_llvm.mlir\00") {addr_space = 1 : i32}
+// CHECK-DAG: llvm.mlir.global internal constant @assertFile_("{{.*}}tritonintelgpu_to_llvm.mlir{{.*}}\00") {addr_space = 1 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_("assert text\00") {addr_space = 1 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_3("different assert text\00") {addr_space = 1 : i32}
 // CHECK-DAG: llvm.func spir_funccc @__assert_fail(!llvm.ptr<4>, !llvm.ptr<4>, i32, !llvm.ptr<4>)
@@ -84,3 +84,8 @@ module attributes { "ttg.threads-per-warp" = 16 : i32, "ttg.num-warps" = 4 : i32
 tt.return
 }
 }
+
+// -----
+
+// Sanity check for the conversion pass to correctly process even empty modules
+module attributes { "ttg.threads-per-warp" = 16 : i32, "ttg.num-warps" = 4 : i32 } {}

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ struct ConvertTritonGPUToLLVM
 TritonIntelGPUToLLVMTypeConverter typeConverter(
     context, option, *targetInfo, isAdvancedPathEnabled);
 TritonLLVMConversionTarget convTarget(*context);
-int numWarps = triton::gpu::lookupNumWarps(&*mod.getOps().begin());
+int numWarps = triton::gpu::lookupNumWarps(mod);
 int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
 int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
 
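
The one-line change can be illustrated with a toy model. The removed expression `&*mod.getOps().begin()` takes the address of the first operation in the module body. For an empty module there is no first operation, so the begin iterator is presumably not dereferenceable, which matches the reported segfault. The sketch below uses hypothetical stand-in types (`ToyOp`, `ToyModule`, `firstOpOrNull`, `lookupNumWarpsToy`), not the MLIR or Triton API, and assumes the warp count is stored as an attribute on the module, as in the test file.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-ins for illustration only; these are NOT MLIR/Triton types.
struct ToyOp {};

struct ToyModule {
  std::optional<int> numWarps; // models the "ttg.num-warps" module attribute
  std::vector<ToyOp> ops;      // module body; may be empty
};

// Models the shape of the removed code, which reached for the first op.
// Dereferencing begin() of an empty range is undefined behavior, so this
// toy version returns nullptr to make the "no first op" case visible.
const ToyOp *firstOpOrNull(const ToyModule &mod) {
  return mod.ops.empty() ? nullptr : &mod.ops.front();
}

// Models the fixed code: query the module itself, so the result does not
// depend on whether the body contains any operations.
int lookupNumWarpsToy(const ToyModule &mod) {
  assert(mod.numWarps.has_value() && "module must carry a num-warps attribute");
  return *mod.numWarps;
}
```

With these definitions, a `ToyModule` with `numWarps = 4` and no ops still yields 4 from `lookupNumWarpsToy`, while `firstOpOrNull` has nothing to return. That is the situation the new empty-module test case exercises.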
