Skip to content

Conversation

@kmehant
Copy link
Collaborator

@kmehant kmehant commented Sep 6, 2025

triton>3.2.0 does not support annotation based syntax for global variables.

This resulted in errors with our stack such as

  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 83, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 33:39:
    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    if BACKWARD_PASS:
        # See our blog post for more info.
        sin1 = -sin1
    pass
    # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
    head_start = group_head_position * ROPE_GROUP_SIZE
                                       ^
NameError("Cannot access global variable ROPE_GROUP_SIZE from within @jit'ed function. Triton kernels can only access global variables that are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported.  Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not promise to support this forever.")

FYA - @HarikrishnanBalagopal @YashasviChaurasia @ashokponkumar

@fabianlim

@kmehant kmehant requested a review from fabianlim as a code owner September 6, 2025 12:34
@ashokponkumar ashokponkumar merged commit 6c16d3b into foundation-model-stack:main Sep 8, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants