
Commit 47cf9e3

fix: update datatype to int (#150)
Signed-off-by: yashasvi <[email protected]>
1 parent: 903c23e

2 files changed: 3 additions, 2 deletions


plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -187,7 +187,7 @@ def _cross_entropy_backward(
     pass
 
 
-MAX_FUSED_SIZE = tl.constexpr(65536) # 2**16
+MAX_FUSED_SIZE = 65536 # 2**16
 
 
 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
```
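The change drops the `tl.constexpr(...)` wrapper so that `MAX_FUSED_SIZE` is a plain Python `int`. A minimal sketch of why that matters, assuming the constant is consumed by host-side launch logic (the helper function below is hypothetical, not code from this file):

```python
import triton

MAX_FUSED_SIZE = 65536  # 2**16, now a plain Python int

def choose_block_size(n_cols: int) -> int:
    # Host-side helpers such as min() and triton.next_power_of_2 expect a
    # real int; a tl.constexpr wrapper is an annotation object intended for
    # kernel signatures, not host arithmetic.
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))

print(choose_block_size(50000))  # -> 65536
```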

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -17,7 +17,7 @@
 import torch
 from .utils import calculate_settings
 
-ROPE_GROUP_SIZE = tl.constexpr(4)
+ROPE_GROUP_SIZE : int = 4
 
 @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
 @triton.jit
@@ -36,6 +36,7 @@ def _rope_embedding(
     RoPE is Q * cos + rotate_half(Q) * sin
     See our blog post for more info
     """
+    ROPE_GROUP_SIZE = 4
     row_position = tl.program_id(0)
     group_head_position = tl.program_id(1)
     col_offsets = tl.arange(0, BLOCK_SIZE)
```
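Here the module-level constant becomes a plain annotated `int` for host-side use, and the value is re-declared inside the `@triton.jit` body, presumably so the kernel no longer depends on capturing a non-`constexpr` module global. A minimal sketch of the pattern (the kernel name and stored values are illustrative, not this file's actual code; running it requires a CUDA-capable device):

```python
import torch
import triton
import triton.language as tl

ROPE_GROUP_SIZE: int = 4  # host-side copy, plain int

@triton.jit
def _group_size_sketch(out_ptr, BLOCK_SIZE: tl.constexpr):
    # Re-bind the constant locally so the JIT body does not rely on
    # capturing a module-level global that is no longer a constexpr.
    ROPE_GROUP_SIZE = 4
    offs = tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offs, offs * ROPE_GROUP_SIZE)

out = torch.empty(8, dtype=torch.int32, device="cuda")
_group_size_sketch[(1,)](out, BLOCK_SIZE=8)
print(out)  # tensor([ 0,  4,  8, 12, 16, 20, 24, 28], ...)
```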
