diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
index 365c695d..59a55cee 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
@@ -187,7 +187,7 @@ def _cross_entropy_backward(
     pass
 
 
-MAX_FUSED_SIZE = tl.constexpr(65536) # 2**16
+MAX_FUSED_SIZE = 65536 # 2**16
 
 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
index 4345f4e3..456ec100 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
@@ -17,7 +17,7 @@ import torch
 from .utils import calculate_settings
 
-ROPE_GROUP_SIZE = tl.constexpr(4)
+ROPE_GROUP_SIZE : int = 4
 
 
 @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
 @triton.jit
@@ -36,6 +36,7 @@ def _rope_embedding(
     RoPE is Q * cos + rotate_half(Q) * sin
     See our blog post for more info
     """
+    ROPE_GROUP_SIZE = 4
     row_position = tl.program_id(0)
     group_head_position = tl.program_id(1)
     col_offsets = tl.arange(0, BLOCK_SIZE)
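Both hunks follow one pattern: drop the `tl.constexpr(...)` wrapper from module-level constants, and re-bind any constant a `@triton.jit` kernel needs inside the kernel body, where the JIT folds it as a compile-time literal. Below is a minimal sketch of that pattern, assuming a recent Triton release; the names `SCALE` and `_scale_kernel` are hypothetical and not part of this patch.

import torch
import triton
import triton.language as tl

SCALE: int = 4  # plain Python int at module scope; host-side launch code can read it

@triton.jit
def _scale_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    SCALE = 4  # re-bound inside the kernel, mirroring the ROPE_GROUP_SIZE change above
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x * SCALE, mask=mask)

# Usage: the module-level SCALE stays available to ordinary Python, while the
# kernel only ever sees its own local literal.
x = torch.ones(1024, device="cuda")
grid = (triton.cdiv(x.numel(), 256),)
_scale_kernel[grid](x, x.numel(), BLOCK_SIZE=256)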