diff --git a/plugins/attention-and-distributed-packing/pyproject.toml b/plugins/attention-and-distributed-packing/pyproject.toml
index e755ac56..fe6fa5d8 100644
--- a/plugins/attention-and-distributed-packing/pyproject.toml
+++ b/plugins/attention-and-distributed-packing/pyproject.toml
@@ -23,7 +23,7 @@ classifiers=[
     "Programming Language :: Python :: 3.11",
 ]
 
-dependencies = ["numba", "trl"]
+dependencies = ["numba", "trl>=0.19.1,<0.20.0"]
 
 [tool.hatch.build.targets.wheel]
 only-include = ["src/fms_acceleration_aadp"]
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index 596b5600..67c75369 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -20,7 +20,9 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from transformers import DataCollatorForSeq2Seq, TrainingArguments
-from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
+from trl import ( # pylint: disable=import-error, no-name-in-module
+    DataCollatorForCompletionOnlyLM,
+)
 import torch
 
 
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
index 3da07310..365c695d 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
@@ -187,7 +187,7 @@ def _cross_entropy_backward(
 pass
 
 
-MAX_FUSED_SIZE: tl.constexpr = 65536 # 2**16
+MAX_FUSED_SIZE = tl.constexpr(65536) # 2**16
 
 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
index c97c8cfc..4345f4e3 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
@@ -17,7 +17,7 @@
 import torch
 from .utils import calculate_settings
 
-ROPE_GROUP_SIZE: tl.constexpr = 4
+ROPE_GROUP_SIZE = tl.constexpr(4)
 
 @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
 @triton.jit