Skip to content

Commit a0e6610

Browse files
Merge commit '04f87d021a3550aa536862aecee21bf0a30a2452'
2 parents 532728c + 04f87d0 commit a0e6610

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

python/test/unit/language/test_libdevice.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import triton
55
import triton.language as tl
66
from triton.language.extra.intel import libdevice
7+
from triton.language.extra.libdevice import fast_dividef as my_fast_dividef
78

89

910
@pytest.mark.parametrize("dtype_str", ["float32", "float64"])
@@ -39,3 +40,20 @@ def kernel(in_p, out_p, fn: tl.constexpr, SIZE: tl.constexpr):
3940
kernel[(1, )](x, y_exp, fn=libdevice_fn, SIZE=SIZE, num_warps=4, num_ctas=1)
4041

4142
torch.testing.assert_close(y_ref, y_exp, equal_nan=True)
43+
44+
45+
def test_libdevice_rename(device):
46+
# mark the import as used by this test
47+
_ = my_fast_dividef
48+
49+
@triton.jit
50+
def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
51+
offsets = tl.arange(0, BLOCK_SIZE)
52+
data = tl.load(in_ptr + offsets)
53+
tl.store(out_ptr + offsets, data)
54+
55+
BLOCK_SIZE = 256
56+
inp = torch.randn(BLOCK_SIZE, device=device)
57+
out = torch.empty_like(inp)
58+
59+
triton_copy[(1, )](inp, out, BLOCK_SIZE)

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
218218

219219
module_name = getattr(v, "__module__", "")
220220
if module_name in module_map:
221-
self.gscope[k] = getattr(module_map[module_name], k)
221+
self.gscope[k] = getattr(module_map[module_name], v.__name__)
222222
else:
223223
self.gscope[k] = v
224224

0 commit comments

Comments
 (0)