
Commit 6c16d3b

fix: update global variable syntax since it's not supported from triton 3.2.0 (#148)
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 1a8ee37 · commit 6c16d3b

4 files changed: +6 -4 lines


plugins/attention-and-distributed-packing/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ classifiers=[
     "Programming Language :: Python :: 3.11",
 ]

-dependencies = ["numba", "trl"]
+dependencies = ["numba", "trl>=0.19.1,<0.20.0"]

 [tool.hatch.build.targets.wheel]
 only-include = ["src/fms_acceleration_aadp"]
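The commit message does not explain the trl pin beyond the diff itself. Purely as an illustration (not part of this change, and assuming the third-party packaging library is available), the constraint can be checked against an installed environment like this:

# Illustrative sketch only, not part of the commit: check the installed trl
# against the range pinned above. Assumes the third-party "packaging" package.
from importlib.metadata import version
from packaging.version import Version

trl_version = Version(version("trl"))
assert Version("0.19.1") <= trl_version < Version("0.20.0"), (
    f"trl {trl_version} falls outside the pinned range >=0.19.1,<0.20.0"
)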

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py

Lines changed: 3 additions & 1 deletion
@@ -20,7 +20,9 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from transformers import DataCollatorForSeq2Seq, TrainingArguments
-from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
+from trl import ( # pylint: disable=import-error, no-name-in-module
+    DataCollatorForCompletionOnlyLM,
+)
 import torch
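The import is only reshaped here (parenthesised, with no-name-in-module added to the pylint suppression); trl must still export DataCollatorForCompletionOnlyLM. Purely as an illustration, and not what the plugin does, a defensive variant of the same import could degrade gracefully when the name is missing:

# Illustrative sketch only: tolerate a trl build that no longer exports the
# collator instead of failing at import time. Not code from this repository.
try:
    from trl import (  # pylint: disable=import-error, no-name-in-module
        DataCollatorForCompletionOnlyLM,
    )
except ImportError:  # also raised when trl itself is not installed
    DataCollatorForCompletionOnlyLM = None  # callers must handle the absence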

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ def _cross_entropy_backward(
     pass


-MAX_FUSED_SIZE: tl.constexpr = 65536 # 2**16
+MAX_FUSED_SIZE = tl.constexpr(65536) # 2**16

 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
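This and the rope_embedding.py change below are the fix the commit title refers to: a module-level constant captured by jitted kernels is now built with the tl.constexpr(...) constructor instead of being declared through a ": tl.constexpr" annotation, which the commit says is not supported from Triton 3.2.0. A minimal sketch of the two forms, assuming a Triton installation and that tl.constexpr exposes the wrapped value as .value:

# Minimal sketch of the syntax change; assumes triton is installed.
import triton.language as tl

# Old form, an annotated module-level global; per this commit it is not
# supported from Triton 3.2.0 onward:
# MAX_FUSED_SIZE: tl.constexpr = 65536

# New form used by the commit: wrap the value in the constexpr constructor.
MAX_FUSED_SIZE = tl.constexpr(65536)  # 2**16

# Outside jitted code the wrapped value is still reachable (assumption: .value).
assert MAX_FUSED_SIZE.value == 65536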

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import torch
 from .utils import calculate_settings

-ROPE_GROUP_SIZE: tl.constexpr = 4
+ROPE_GROUP_SIZE = tl.constexpr(4)

 @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
 @triton.jit
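Same syntax change as above. In the unsloth-derived RoPE kernel this constant caps how many attention heads each Triton program instance handles; the pure-Python sketch below only illustrates that grouping arithmetic and is not code from the repository.

# Illustrative only: how a group size of 4 partitions n_heads into per-program groups.
ROPE_GROUP_SIZE = 4

def n_rope_groups(n_heads: int) -> int:
    # ceil division; the last group may hold fewer than ROPE_GROUP_SIZE heads
    return (n_heads + ROPE_GROUP_SIZE - 1) // ROPE_GROUP_SIZE

print(n_rope_groups(32))  # 8 full groups of 4 heads
print(n_rope_groups(30))  # 8 groups: 7 full groups plus one group of 2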

0 commit comments