Skip to content

Commit 0ca15d8

Browse files
authored
[Gluon] Fix tensor.sum (#7617)
Currently `CodeGenerator` uses `fn.is_gluon()` to choose which semantic & builder to use, but this runs into issues when trying to share code between triton and gluon. Instead, this uses the `fn.is_gluon` from the top level function and propagates it through the call graph regardless of which annotation is used on the function being called.
1 parent e8c4711 commit 0ca15d8

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

python/test/gluon/test_frontend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def reduce_kernel(out):
948948
layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
949949
a = ttgl.full([16, 16], 1, ttgl.float32, layout)
950950
b = ttgl.full([16, 16], 2, ttgl.float32, layout)
951-
s0 = ttgl.sum(a, 0)
951+
s0 = a.sum(0)
952952
ttgl.static_assert(s0.type.layout == ttgl.SliceLayout(0, layout))
953953
s1 = ttgl.sum(a, 1)
954954
ttgl.static_assert(s1.type.layout == ttgl.SliceLayout(1, layout))

python/triton/compiler/code_generator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,12 @@ class BoundJITMethod:
294294

295295
class CodeGenerator(ast.NodeVisitor):
296296

297-
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
298-
module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False,
299-
caller_context=None, file_name: Optional[str] = None, begin_line=0):
297+
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
298+
module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
299+
noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
300300
self.context = context
301-
if jit_fn.is_gluon():
301+
self.is_gluon = is_gluon
302+
if is_gluon:
302303
from triton.experimental.gluon.language._semantic import GluonSemantic
303304
self.builder = gluon_ir.GluonOpBuilder(context)
304305
self.semantic = GluonSemantic(self.builder)
@@ -1253,7 +1254,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
12531254
function_name=fn_name, function_types=self.function_ret_types,
12541255
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
12551256
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
1256-
module_map=self.builder.module_map, caller_context=caller_context)
1257+
module_map=self.builder.module_map, caller_context=caller_context,
1258+
is_gluon=self.is_gluon)
12571259
try:
12581260
generator.visit(fn.parse())
12591261
except Exception as e:
@@ -1527,7 +1529,7 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None)
15271529
proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
15281530
generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
15291531
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
1530-
codegen_fns=codegen_fns, module_map=module_map, module=module)
1532+
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
15311533
generator.visit(fn.parse())
15321534
ret = generator.module
15331535
# module takes ownership of the context

0 commit comments

Comments
 (0)