Skip to content

Commit 158473c

Browse files
davidberard98liuyunqi20
authored andcommitted
[FRONTEND] Fix handling of from m import x as y in CodeGenerator (#5081)
Context: in `CodeGenerator.__init__`, globals for a given triton function are modified to handle remapping the libdevice module to cuda or hip (from triton-lang/triton#4539). In particular, this logic: ```python for k, v in gscope.items(): # gscope is a dict of fn.__globals__ ... self.gscope[k] = getattr(module_map[module_name], k) ``` was failing if you do this in the global scope: `from triton.language.extras.libdevice import fast_dividef as my_fast_dividef`.
1 parent 410e65e commit 158473c

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
3+
import triton
4+
import triton.language as tl
5+
6+
from triton.language.extra.libdevice import fast_dividef as my_fast_dividef
7+
8+
9+
def test_libdevice_rename(device):
10+
# mark the import as used by this test
11+
_ = my_fast_dividef
12+
13+
@triton.jit
14+
def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
15+
offsets = tl.arange(0, BLOCK_SIZE)
16+
data = tl.load(in_ptr + offsets)
17+
tl.store(out_ptr + offsets, data)
18+
19+
BLOCK_SIZE = 256
20+
inp = torch.randn(BLOCK_SIZE, device=device)
21+
out = torch.empty_like(inp)
22+
23+
triton_copy[(1, )](inp, out, BLOCK_SIZE)

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
218218

219219
module_name = getattr(v, "__module__", "")
220220
if module_name in module_map:
221-
self.gscope[k] = getattr(module_map[module_name], k)
221+
self.gscope[k] = getattr(module_map[module_name], v.__name__)
222222
else:
223223
self.gscope[k] = v
224224

0 commit comments

Comments
 (0)