Skip to content

Commit ab25b49

Browse files
authored
[TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA targets (#18671)
### Motivation InjectPTXLDG32 rewrites BufferStore when encountering if_then_else, but it only initializes temporary buffers when an Allocate node exists. For functions without Allocate, this leads to uninitialized buffers and a hard segfault during compilation. In addition, the PTX-only pass can run on CPU/LLVM targets when tir.ptx_ldg32=1, injecting PTX intrinsics that are invalid for non-CUDA codegen. This PR ensures temporary buffers are created even when no Allocate exists, and skips InjectPTXLDG32 on non-CUDA targets, preventing segfaults and invalid PTX intrinsics on CPU. ### Changes - Ensure temp buffers are created when the rewrite path is taken without Allocate - Insert allocations at the function level when needed - Guard InjectPTXLDG32 so it only runs on CUDA targets - Add tests for CUDA (insertion) and CPU (skip) behavior ### Testing test_tir_transform_inject_ptx_ldg32.py ### Fixes - [#18612](#18612) - [#18617](#18617) - [#18599](#18599)
1 parent 66f7f37 commit ab25b49

File tree

2 files changed

+115
-9
lines changed

2 files changed

+115
-9
lines changed

src/tir/transforms/inject_ptx_ldg32.cc

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,22 @@ namespace tir {
3535

3636
class PTXRewriter : public StmtMutator {
3737
public:
38-
Stmt VisitStmt_(const AllocateNode* allocate) final {
39-
if (!has_buffer_1) {
40-
has_buffer_1 = true;
41-
// addr[0] -> global_addr / addr[1] -> local_addr
42-
addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
43-
predicate_buffer =
44-
decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
38+
Stmt AddAllocationsIfNeeded(Stmt body) {
39+
if (!needs_buffer || has_buffer_2) {
40+
return body;
4541
}
42+
EnsureBuffers();
43+
body = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), body);
44+
body = Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape,
45+
Bool(true), body);
46+
has_buffer_2 = true;
47+
return body;
48+
}
49+
50+
Stmt VisitStmt_(const AllocateNode* allocate) final {
4651
Stmt result = StmtMutator::VisitStmt_(allocate);
47-
if (!has_buffer_2) {
52+
if (needs_buffer && !has_buffer_2) {
53+
EnsureBuffers();
4854
has_buffer_2 = true;
4955
result =
5056
Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), result);
@@ -82,6 +88,8 @@ class PTXRewriter : public StmtMutator {
8288
if (ramp != nullptr) {
8389
return result;
8490
}
91+
EnsureBuffers();
92+
needs_buffer = true;
8593
local_addr = store->indices[0];
8694
BufferStore addr_store(addr_buffer, global_addr, {IntImm(DataType::Int(32), 0)});
8795
BufferStore local_addr_store(addr_buffer, local_addr, {IntImm(DataType::Int(32), 1)});
@@ -104,7 +112,19 @@ class PTXRewriter : public StmtMutator {
104112
return result;
105113
}
106114

115+
void EnsureBuffers() {
116+
if (has_buffer_1) {
117+
return;
118+
}
119+
has_buffer_1 = true;
120+
// addr[0] -> global_addr / addr[1] -> local_addr
121+
addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
122+
predicate_buffer =
123+
decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
124+
}
125+
107126
bool has_buffer_1 = false, has_buffer_2 = false;
127+
bool needs_buffer = false;
108128
Buffer addr_buffer, predicate_buffer;
109129
};
110130

@@ -113,8 +133,14 @@ namespace transform {
113133
Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
114134
auto pass_func = [enable_inject_ptx_intrin](PrimFunc f, IRModule m, PassContext ctx) {
115135
if (enable_inject_ptx_intrin) {
136+
auto target = f->GetAttr<Target>("target");
137+
if (!target.defined() || target.value()->kind->name != "cuda") {
138+
return f;
139+
}
116140
auto* n = f.CopyOnWrite();
117-
n->body = PTXRewriter()(n->body);
141+
PTXRewriter rewriter;
142+
Stmt body = rewriter(n->body);
143+
n->body = rewriter.AddAllocationsIfNeeded(body);
118144
// inject ptx
119145
}
120146
return f;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
import tvm.testing
20+
from tvm.script import tir as T
21+
22+
23+
def _count_alloc(stmt):
24+
num_alloc = [0]
25+
26+
def visit(n):
27+
if isinstance(n, tvm.tir.Allocate):
28+
num_alloc[0] += 1
29+
30+
tvm.tir.stmt_functor.post_order_visit(stmt, visit)
31+
return num_alloc[0]
32+
33+
34+
def _count_ptx_ldg32(stmt):
35+
num_call = [0]
36+
37+
def visit(n):
38+
if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_ldg32":
39+
num_call[0] += 1
40+
41+
tvm.tir.stmt_functor.post_order_visit(stmt, visit)
42+
return num_call[0]
43+
44+
45+
@T.prim_func
46+
def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None:
47+
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("cuda")})
48+
for i in range(4):
49+
C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
50+
51+
52+
@T.prim_func
53+
def where_no_alloc_cpu(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None:
54+
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
55+
for i in range(4):
56+
C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
57+
58+
59+
def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func():
60+
mod = tvm.IRModule.from_expr(where_no_alloc)
61+
assert _count_alloc(mod["main"].body) == 0
62+
63+
mod = tvm.tir.transform.InjectPTXLDG32()(mod)
64+
assert _count_alloc(mod["main"].body) > 0
65+
assert _count_ptx_ldg32(mod["main"].body) == 1
66+
67+
68+
def test_inject_ptx_ldg32_skip_non_cuda_target():
69+
mod = tvm.IRModule.from_expr(where_no_alloc_cpu)
70+
cpu_target = tvm.target.Target("llvm")
71+
mod = tvm.IRModule({"main": mod["main"].with_attr("target", cpu_target)})
72+
assert _count_alloc(mod["main"].body) == 0
73+
74+
mod = tvm.tir.transform.InjectPTXLDG32()(mod)
75+
assert _count_alloc(mod["main"].body) == 0
76+
assert _count_ptx_ldg32(mod["main"].body) == 0
77+
78+
79+
if __name__ == "__main__":
80+
tvm.testing.main()

0 commit comments

Comments
 (0)