Skip to content

Commit ff86d26

Browse files
authored
[Frontend] Use full module and qualname when mangling non-kernel function names (#7025)
Without this, calling two different Python jit functions with the same name will cause a name collision.
1 parent 16a87b4 commit ff86d26

File tree

6 files changed

+66
-31
lines changed

6 files changed

+66
-31
lines changed

python/test/gluon/test_frontend.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,25 +307,25 @@ def anchor(x):
307307
@filecheck_test
308308
@gluon.jit
309309
def test_warp_specialize():
310-
# CHECK-LABEL: tt.func public @test_warp_specialize
310+
# CHECK-LABEL: test_warp_specialize
311311
# CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
312312
# CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
313313
# CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
314314
# CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
315315
# CHECK-NEXT: default {
316-
# CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @"warp_specialize_default{{.*}}"([[A]], [[B]], [[C]])
316+
# CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
317317
# CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
318318
# CHECK-NEXT: }
319319
# CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
320-
# CHECK-NEXT: call @"warp_specialize_worker0{{.*}}"(%arg0, %arg1, %arg2)
320+
# CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
321321
# CHECK-NEXT: warp_return
322322
# CHECK-NEXT: }
323323
# CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
324-
# CHECK-NEXT: call @"warp_specialize_worker1{{.*}}"(%arg0, %arg1, %arg2)
324+
# CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
325325
# CHECK-NEXT: warp_return
326326
# CHECK-NEXT: }
327-
# CHECK-NEXT: call @anchor{{.*}}([[OUTS]]#0)
328-
# CHECK-NEXT: call @"anchor{{.*}}"([[OUTS]]#1, [[OUTS]]#2)
327+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
328+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#1, [[OUTS]]#2)
329329
pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
330330
a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
331331
[warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
@@ -584,10 +584,10 @@ def kernel():
584584
module {
585585
tt.func public @kernel() attributes {noinline = false} {
586586
%0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
587-
tt.call @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
587+
tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
588588
tt.return
589589
}
590-
tt.func private @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
590+
tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
591591
tt.return
592592
}
593593
}

python/test/unit/language/test_frontend.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_assign_attribute():
4242
scalar = 11
4343
pair = Pair(tl.arange(0, 4), scalar)
4444
# CHECK: %c42_i32 = arith.constant 42 : i32
45-
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
45+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], %c42_i32)
4646
pair.second = 42
4747
anchor(pair)
4848

@@ -58,7 +58,7 @@ def test_augassign_attribute():
5858
# CHECK: %c42_i32 = arith.constant 42 : i32
5959
# CHECK: [[VALUE:%.*]] = arith.addi %c11_i32, %c42_i32
6060
pair.second += 42
61-
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], [[VALUE]])
61+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], [[VALUE]])
6262
anchor(pair)
6363

6464

@@ -69,12 +69,12 @@ def test_jit_method():
6969
# CHECK: %c11_i32 = arith.constant 11 : i32
7070
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
7171
scalar = 11
72-
# CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
72+
# CHECK: [[V:%.*]]:2 = tt.call @{{.*}}unpack{{.*}}([[RANGE]], %c11_i32)
7373
pair = Pair(tl.arange(0, 4), scalar)
7474
a, b = pair.unpack()
75-
# CHECK: call @anchor{{.*}}([[V]]#0)
75+
# CHECK: call @{{.*}}anchor{{.*}}([[V]]#0)
7676
anchor(a)
77-
# CHECK: call @anchor{{.*}}([[V]]#1)
77+
# CHECK: call @{{.*}}anchor{{.*}}([[V]]#1)
7878
anchor(b)
7979

8080

@@ -95,10 +95,10 @@ def test_aggregate_initializers():
9595
# CHECK-LABEL: test_aggregate_initializers
9696
value = TypeWithBuiltinInitializer()
9797
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
98-
# CHECK: call @"anchor{{.*}}"([[RANGE]])
98+
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
9999
anchor(value)
100100
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
101-
# CHECK: call @"anchor{{.*}}"([[RANGE]])
101+
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
102102
value.modify(tl.arange(4, 8))
103103
anchor(value)
104104

@@ -118,11 +118,11 @@ def list_of_functions_constexpr(arg, fns: tl.constexpr):
118118
@triton.jit
119119
def test_list_of_functions():
120120
# CHECK-LABEL: test_list_of_functions
121-
# CHECK: call @"list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward)"
121+
# CHECK: call @{{.*}}list_of_functions_constexpr{{.*}}cJITFunction(test_frontend:anchor){{.*}}cJITFunction(test_frontend:forward)
122122

123-
# CHECK-LABEL: tt.func private @"list_of_functions_constexpr
124-
# CHECK-NEXT: call @anchor
125-
# CHECK-NEXT: call @forward
123+
# CHECK: tt.func private @{{.*}}list_of_functions_constexpr
124+
# CHECK-NEXT: call @{{.*}}anchor
125+
# CHECK-NEXT: call @{{.*}}forward
126126
list_of_functions_constexpr(tl.arange(0, 4), [anchor, forward])
127127

128128

@@ -138,11 +138,34 @@ def test_call_in_loop():
138138
# CHECK-LABEL: test_call_in_loop
139139
acc = 0
140140
# CHECK: scf.for
141-
# CHECK: call @accumulate
141+
# CHECK: call @{{.*}}accumulate
142142
for i in range(10):
143143
acc = accumulate(acc, i)
144144

145145

146+
@tl.core._aggregate
147+
class FunctionParent:
148+
149+
@triton.jit
150+
def function_with_name():
151+
pass
152+
153+
154+
@triton.jit
155+
def function_with_name():
156+
pass
157+
158+
159+
@filecheck_test
160+
@triton.jit
161+
def test_function_name_mangling():
162+
# CHECK-LABEL: test_function_name_mangling
163+
# CHECK: call @test_frontend.function_with_name
164+
# CHECK: call @test_frontend.FunctionParent.function_with_name
165+
function_with_name()
166+
FunctionParent.function_with_name()
167+
168+
146169
@tl.core._aggregate
147170
class AggregateWithConstexpr:
148171
a: tl.tensor
@@ -166,10 +189,10 @@ def add_rhs_constexpr(agg):
166189
@triton.jit
167190
def test_aggregate_with_constexpr():
168191
# CHECK-LABEL: test_aggregate_with_constexpr
169-
# CHECK: tt.call @"add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
192+
# CHECK: tt.call @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
170193
agg = AggregateWithConstexpr.create(tl.arange(0, 4))
171194
add_rhs_constexpr(agg)
172195

173-
# CHECK: tt.func private @"add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
196+
# CHECK: tt.func private @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
174197
# CHECK: %cst = arith.constant dense<42> : tensor<4xi32>
175198
# CHECK: arith.addi %arg0, %cst : tensor<4xi32>
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import triton
2+
3+
4+
@triton.jit
5+
def function_with_name():
6+
pass

python/test/unit/test_filecheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_kernel():
1717
# CHECK-LABEL: test_kernel
1818
scalar = 42
1919
# CHECK: %c42_i32 = arith.constant 42 : i32
20-
# CHECK-NEXT: call @anchor{{.*}}(%c42_i32) : (i32) -> ()
20+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}(%c42_i32) : (i32) -> ()
2121
anchor(scalar)
2222

2323
run_filecheck_test(test_kernel)

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .._C.libtriton import ir, gluon_ir
1414
from ..language import constexpr, semantic, str_to_ty, tensor
1515
from ..language.core import _unwrap_if_constexpr, base_value, base_type
16-
from ..runtime.jit import get_jit_fn_file_line
16+
from ..runtime.jit import get_jit_fn_file_line, get_full_name
1717
# ideally we wouldn't need any runtime component
1818
from ..runtime import JITFunction
1919
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
@@ -315,6 +315,7 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunctio
315315
self.jit_fn = jit_fn
316316
# TODO: we currently generate illegal names for non-kernel functions involving constexprs!
317317
if is_kernel:
318+
function_name = function_name[function_name.rfind('.') + 1:]
318319
function_name = check_identifier_legality(function_name, "function")
319320
self.function_name = function_name
320321
self.is_kernel = is_kernel
@@ -1200,7 +1201,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12001201
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
12011202
args_val = [get_iterable_path(args, path) for path in args_path]
12021203
# mangle
1203-
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
1204+
fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst)
12041205
# generate function def if necessary
12051206
if not self.module.has_function(fn_name):
12061207
gscope = fn.__globals__

python/triton/runtime/jit.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,10 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
485485
type_canonicalisation_dict[v] = v
486486

487487

488+
def get_full_name(fn):
489+
return f"{fn.__module__}.{fn.__qualname__}"
490+
491+
488492
@dataclass
489493
class JitFunctionInfo:
490494
module: ModuleType
@@ -511,7 +515,7 @@ def _call_hook(
511515
if not hook:
512516
return None
513517

514-
name = self.fn.__name__
518+
name = get_full_name(self.fn)
515519
module = self.fn.__module__
516520
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
517521
repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
@@ -653,7 +657,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
653657
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
654658
self.starting_line_number = inspect.getsourcelines(fn)[1]
655659
self._repr = repr
656-
self._fn_name = fn.__name__
660+
self._fn_name = get_full_name(fn)
657661
self.launch_metadata = launch_metadata
658662

659663
self.params = []
@@ -698,14 +702,15 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
698702
# reuse docs of wrapped function
699703
self.__doc__ = fn.__doc__
700704
self.__name__ = fn.__name__
705+
self.__qualname__ = fn.__qualname__
701706
self.__globals__ = fn.__globals__
702707
self.__module__ = fn.__module__
703708

704709
@property
705710
def cache_key(self):
706711
# TODO : hash should be attribute of `self`
707712
if self.hash is None:
708-
dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
713+
dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, src=self.src)
709714
dependencies_finder.visit(self.parse())
710715
self.hash = dependencies_finder.ret + str(self.starting_line_number)
711716
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
@@ -725,9 +730,9 @@ def preload(self, specialization_data):
725730
import triton.language as tl
726731
device = driver.active.get_current_device()
727732
deserialized_obj = json.loads(specialization_data)
728-
if deserialized_obj['name'] != self.fn.__name__:
733+
if deserialized_obj['name'] != self._fn_name:
729734
raise RuntimeError(
730-
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
735+
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
731736
constant_keys = map(tuple, deserialized_obj['constant_keys'])
732737
constant_vals = deserialized_obj['constant_vals']
733738
constants = {
@@ -778,7 +783,7 @@ def _unsafe_update_src(self, new_src):
778783
super().__setattr__('src', new_src)
779784

780785
def __repr__(self):
781-
return f"JITFunction({self.module}:{self.fn.__name__})"
786+
return f"JITFunction({self.module}:{self.fn.__qualname__})"
782787

783788

784789
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)