From 3837bc9916a6c158e2e8bccce426370347aa1bb4 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 20 Nov 2025 11:51:05 +0800 Subject: [PATCH 1/4] support dpo orpo and simpo --- kt-sft/ktransformers/dpo/__init__.py | 0 kt-sft/ktransformers/dpo/trainer.py | 201 +++++++++++++++++++++ kt-sft/ktransformers/local_chat.py | 8 +- kt-sft/ktransformers/operators/linear.py | 2 +- kt-sft/ktransformers/sft/lora.py | 18 +- kt-sft/ktransformers/util/__init__.py | 0 kt-sft/ktransformers/util/grad_wrapper.py | 2 +- kt-sft/ktransformers/util/trainer_utils.py | 20 ++ kt-sft/ktransformers/util/utils.py | 2 +- 9 files changed, 229 insertions(+), 24 deletions(-) create mode 100644 kt-sft/ktransformers/dpo/__init__.py create mode 100644 kt-sft/ktransformers/dpo/trainer.py create mode 100644 kt-sft/ktransformers/util/__init__.py create mode 100644 kt-sft/ktransformers/util/trainer_utils.py diff --git a/kt-sft/ktransformers/dpo/__init__.py b/kt-sft/ktransformers/dpo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kt-sft/ktransformers/dpo/trainer.py b/kt-sft/ktransformers/dpo/trainer.py new file mode 100644 index 00000000..6355f620 --- /dev/null +++ b/kt-sft/ktransformers/dpo/trainer.py @@ -0,0 +1,201 @@ +from trl import DPOTrainer +import os +from packaging import version +import inspect +import functools +from typing import Union, Any, Dict, List + +import torch +from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import Dataset as TorchDataset + + +from transformers.training_args import OptimizerNames + +from transformers.trainer_utils import seed_worker +from transformers.utils import ( + is_datasets_available, + is_sagemaker_mp_enabled, + is_torch_xpu_available, + is_torch_mlu_available, + is_torch_musa_available, + is_torch_npu_available, + is_torch_mps_available, + is_torch_hpu_available, + is_accelerate_available, + is_apex_available, + logging, +) + +from ktransformers.util.trainer_utils import KAccelerator + + +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration +from accelerate import __version__ as accelerate_version +if version.parse(accelerate_version) > version.parse("1.3.0"): + from accelerate.utils import TorchTensorParallelPlugin +if is_sagemaker_mp_enabled(): + from transformers.trainer_utils import smp_forward_backward + +logger = logging.get_logger(__name__) + +class KTDporainer(DPOTrainer): + def save_model(self, output_dir=None, _internal_call=False): + output_dir = output_dir or self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + # only save LoRA adapter, including adapter_config.json + self.model.save_pretrained(output_dir) + + def _move_model_to_device(self, model, device): + print("[KTrainer] Due to the placement feature in KTransformers, skip moving model to", device) + return model + + def _wrap_model(self, model, training=True, dataloader=None): + self.model_wrapped = model + return model + + def create_accelerator_and_postprocess(self): + # We explicitly don't rely on the `Accelerator` to do gradient accumulation + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. 
+ raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] + + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + # Extract dataloader config params from accelerator config + dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"] + dataloader_config_dict = {param: accelerator_config.pop(param) for param in dataloader_params if param in accelerator_config} + if DataLoaderConfiguration is None: + raise ImportError("Your accelerate does not provide DataLoaderConfiguration but Trainer expects it.") + dataloader_config = DataLoaderConfiguration(**dataloader_config_dict) + if is_accelerate_available("1.1.0"): + dataloader_config.data_seed = self.args.data_seed + else: + dataloader_config = None + + non_blocking = accelerator_config.pop("non_blocking", False) + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning \ + ("`non_blocking` is enabled but `dataloader_pin_memory` is not. For best performance, enable both.") + if dataloader_config is not None: + dataloader_config.non_blocking = non_blocking + + accelerator_config.pop("gradient_accumulation_kwargs", None) + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + "device_placement": False, + } + + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) + + if getattr(self.args, "tp_size", 1) > 1: + self.is_tp_enabled = True + if version.parse(accelerate_version) > version.parse("1.3.0") and TorchTensorParallelPlugin is not None: + args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size) + else: + raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.") + + self.accelerator = KAccelerator(**args) + + try: + self.accelerator.state.device_ids = [0] + self.accelerator.state.num_processes = 1 + self.accelerator.state.num_gpus = 1 + except Exception: + pass + + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + for param in ["limit_all_gathers", "activation_checkpointing"]: + setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param))) + if fsdp_plugin.activation_checkpointing and 
self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." + ) + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end` + if ( + self.args.save_only_model + and (self.is_deepspeed_enabled or self.is_fsdp_enabled) + and self.args.load_best_model_at_end + ): + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") + + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) + if ( + self.args.save_only_model + and self.is_fsdp_enabled + and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type) + ): + raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'") + + if dataloader_config is not None: + dataloader_config.split_batches = False + dataloader_config.dispatch_batches = False + dataloader_config.even_batches = False + + def post_training_step(self, loss): + if loss.device != self.args.device: + ret = loss.to(self.args.device, non_blocking=True) + return loss + + def training_step( + self, + model: torch.nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + num_items_in_batch=None + ) -> torch.Tensor: + + ret = super().training_step(model, inputs, num_items_in_batch=num_items_in_batch) + ret = self.post_training_step(ret) + return ret \ No newline at end of file diff --git a/kt-sft/ktransformers/local_chat.py b/kt-sft/ktransformers/local_chat.py index 432782ae..7f2d8990 100644 --- a/kt-sft/ktransformers/local_chat.py +++ b/kt-sft/ktransformers/local_chat.py @@ -111,7 +111,7 @@ def local_chat( torch.set_grad_enabled(False) if is_sft == True or use_adapter == True: - GLOBAL_CONFIG._config["mod"] = "sft" + GLOBAL_CONFIG._config["mod"] = "train" else: GLOBAL_CONFIG._config["mod"] = "infer" @@ -178,13 +178,13 @@ def local_chat( if is_sft == True: if use_adapter == True or is_test_data == True: raise AttributeError("We do not support to run sft and inference at the same time.") - GLOBAL_CONFIG._config["mod"] = "sft" + GLOBAL_CONFIG._config["mod"] = "train" print(f"sft with lora in dataset: {sft_data_path} ...") print(f"use_cuda_graph:{use_cuda_graph}") lora_and_load_adapter(model, tokenizer, sft_data_path, save_adapter_path) if use_adapter == True: - GLOBAL_CONFIG._config["mod"] = "sft" + GLOBAL_CONFIG._config["mod"] = "train" if is_sft == True: raise AttributeError("We do not support more than one adapter up to now...") @@ -261,7 +261,7 @@ def local_chat( # else: # os.system("clear") - if GLOBAL_CONFIG._config["mod"] == "sft" : + if GLOBAL_CONFIG._config["mod"] == "train" : model.model.embed_tokens.to("cpu") if is_test_data: diff --git a/kt-sft/ktransformers/operators/linear.py b/kt-sft/ktransformers/operators/linear.py index d617bcd4..ebb92aad 100644 --- a/kt-sft/ktransformers/operators/linear.py +++ b/kt-sft/ktransformers/operators/linear.py @@ -159,7 
+159,7 @@ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> t dtype = x.dtype out_device = x.device - if (not x.requires_grad) and GLOBAL_CONFIG._config["mod"] == "sft": + if (not x.requires_grad) and GLOBAL_CONFIG._config["mod"] == "train": x = x.requires_grad_(True) # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended. x = x.to(device=self.device, dtype=self.dtype) diff --git a/kt-sft/ktransformers/sft/lora.py b/kt-sft/ktransformers/sft/lora.py index e862dba3..0c3d83ca 100644 --- a/kt-sft/ktransformers/sft/lora.py +++ b/kt-sft/ktransformers/sft/lora.py @@ -35,6 +35,7 @@ import os, json from pathlib import Path from accelerate import Accelerator +from ktransformers.util.trainer_utils import KAccelerator if is_accelerate_available("0.28.0"): from accelerate.utils import DataLoaderConfiguration from accelerate import __version__ as accelerate_version @@ -47,23 +48,6 @@ logger = logging.get_logger(__name__) -class KAccelerator(Accelerator): - def __init__(self, *args, **kwargs): - kwargs.setdefault("device_placement", False) - super().__init__(*args, **kwargs) - - def prepare_model(self, model, *args, **kwargs): - return model - - def prepare(self, *args, **kwargs): - prepped = [] - for obj in args: - if isinstance(obj, nn.Module): - prepped.append(self.prepare_model(obj, **kwargs)) - else: - prepped.append(super().prepare(obj, **kwargs)) - return tuple(prepped) if len(prepped) > 1 else prepped[0] - class KTrainer(Trainer): def save_model(self, output_dir=None, _internal_call=False): output_dir = output_dir or self.args.output_dir diff --git a/kt-sft/ktransformers/util/__init__.py b/kt-sft/ktransformers/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kt-sft/ktransformers/util/grad_wrapper.py b/kt-sft/ktransformers/util/grad_wrapper.py index 33595423..ffddd5ce 100644 --- a/kt-sft/ktransformers/util/grad_wrapper.py +++ b/kt-sft/ktransformers/util/grad_wrapper.py @@ -16,7 +16,7 @@ def decorator(func): # print(f"decorate_sit: {GLOBAL_CONFIG._config['mod']}") def wrapper(*args, **kwargs): # print(f"wrap_sit: {GLOBAL_CONFIG._config['mod']}") - if GLOBAL_CONFIG._config["mod"] == "sft": + if GLOBAL_CONFIG._config["mod"] == "train": return func(*args, **kwargs) elif GLOBAL_CONFIG._config["mod"] == "infer": with torch.no_grad(): diff --git a/kt-sft/ktransformers/util/trainer_utils.py b/kt-sft/ktransformers/util/trainer_utils.py new file mode 100644 index 00000000..7a72bc7c --- /dev/null +++ b/kt-sft/ktransformers/util/trainer_utils.py @@ -0,0 +1,20 @@ + +from accelerate import Accelerator +import torch.nn as nn + +class KAccelerator(Accelerator): + def __init__(self, *args, **kwargs): + kwargs.setdefault("device_placement", False) + super().__init__(*args, **kwargs) + + def prepare_model(self, model, *args, **kwargs): + return model + + def prepare(self, *args, **kwargs): + prepped = [] + for obj in args: + if isinstance(obj, nn.Module): + prepped.append(self.prepare_model(obj, **kwargs)) + else: + prepped.append(super().prepare(obj, **kwargs)) + return tuple(prepped) if len(prepped) > 1 else prepped[0] \ No newline at end of file diff --git a/kt-sft/ktransformers/util/utils.py b/kt-sft/ktransformers/util/utils.py index 1432ec76..5a8f4781 100644 --- a/kt-sft/ktransformers/util/utils.py +++ b/kt-sft/ktransformers/util/utils.py @@ -164,7 +164,7 @@ def get_all_used_cuda_device(device_map:dict): return all_device_list def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda", adapter_gguf: 
bool = False): - if GLOBAL_CONFIG._config["mod"] == 'sft': + if GLOBAL_CONFIG._config["mod"] == 'train': prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) From 759dd734550d568595275b660c1d5e264e17cf6f Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 20 Nov 2025 13:46:01 +0800 Subject: [PATCH 2/4] fix bug for KTDpoTrainer --- kt-sft/ktransformers/dpo/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kt-sft/ktransformers/dpo/trainer.py b/kt-sft/ktransformers/dpo/trainer.py index 6355f620..c54b66af 100644 --- a/kt-sft/ktransformers/dpo/trainer.py +++ b/kt-sft/ktransformers/dpo/trainer.py @@ -40,7 +40,7 @@ logger = logging.get_logger(__name__) -class KTDporainer(DPOTrainer): +class KTDpoTrainer(DPOTrainer): def save_model(self, output_dir=None, _internal_call=False): output_dir = output_dir or self.args.output_dir os.makedirs(output_dir, exist_ok=True) @@ -186,7 +186,7 @@ def create_accelerator_and_postprocess(self): def post_training_step(self, loss): if loss.device != self.args.device: - ret = loss.to(self.args.device, non_blocking=True) + loss = loss.to(self.args.device, non_blocking=True) return loss def training_step( From 9bad071bae11e8a399f16d0a756d64d4d41c8671 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 20 Nov 2025 13:55:42 +0800 Subject: [PATCH 3/4] add trl lib requirement --- kt-sft/requirements-sft.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kt-sft/requirements-sft.txt b/kt-sft/requirements-sft.txt index 3be1a44e..8b1be955 100644 --- a/kt-sft/requirements-sft.txt +++ b/kt-sft/requirements-sft.txt @@ -30,4 +30,5 @@ torchviz==0.0.3 tzdata==2025.2 xxhash==3.5.0 yarl==1.20.0 -torchviz \ No newline at end of file +torchviz +trl==0.9.6 \ No newline at end of file From caac15d95826cc37a274b738e9c6b1fad77972ca Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 25 Nov 2025 19:06:40 +0800 Subject: [PATCH 4/4] =?UTF-8?q?1=E3=80=81remove=20invalid=20loss=20comput?= =?UTF-8?q?=20for=20dpo=20to=20save=20mem;2=E3=80=81add=20qwen3-235-muti-g?= =?UTF-8?q?pu=20optimizer=20rule?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kt-sft/ktransformers/dpo/trainer.py | 84 +++++++++- .../Qwen3Moe-235b-sft-gpu-2-amx.yaml | 157 ++++++++++++++++++ kt-sft/ktransformers/util/trainer_utils.py | 70 +++++++- 3 files changed, 308 insertions(+), 3 deletions(-) create mode 100644 kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-235b-sft-gpu-2-amx.yaml diff --git a/kt-sft/ktransformers/dpo/trainer.py b/kt-sft/ktransformers/dpo/trainer.py index c54b66af..9d56d585 100644 --- a/kt-sft/ktransformers/dpo/trainer.py +++ b/kt-sft/ktransformers/dpo/trainer.py @@ -4,6 +4,7 @@ import inspect import functools from typing import Union, Any, Dict, List +from typing_extensions import override import torch from torch.utils.data import DataLoader, IterableDataset @@ -27,7 +28,7 @@ logging, ) -from ktransformers.util.trainer_utils import KAccelerator +from ktransformers.util.trainer_utils import KAccelerator, nested_detach, get_batch_logps if is_accelerate_available("0.28.0"): @@ -184,6 +185,53 @@ def create_accelerator_and_postprocess(self): dataloader_config.dispatch_batches = False dataloader_config.even_batches = False + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training DataLoader with 
per_device_train_batch_size + (no implicit multipliers by number of visible GPUs). + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + + if is_datasets_available(): + try: + import datasets + if isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + except Exception: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + if self.args.dataloader_num_workers > 0 and self.args.dataloader_prefetch_factor is not None: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + dl = DataLoader(train_dataset, **dataloader_params) + + try: + prepared = self.accelerator.prepare(dl, device_placement=[False]) + except TypeError: + prepared = self.accelerator.prepare(dl) + + return prepared + def post_training_step(self, loss): if loss.device != self.args.device: loss = loss.to(self.args.device, non_blocking=True) @@ -198,4 +246,36 @@ def training_step( ret = super().training_step(model, inputs, num_items_in_batch=num_items_in_batch) ret = self.post_training_step(ret) - return ret \ No newline at end of file + return ret + + @override + def concatenated_forward( + self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. + + Otherwise the average log probabilities. 
+ """ + if self.finetuning_args.use_ref_model: + batch = nested_detach(batch, clone=True) # avoid error + labels = batch["labels"] + # dpo not need compute loss in forward, waste mem + del batch["labels"] + all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) + all_logits = all_logits.to("cpu") + labels = labels.to(all_logits.device) + all_logps, valid_length = get_batch_logps( + logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None) + ) + if self.loss_type in ["ipo", "orpo", "simpo"]: + all_logps = all_logps / valid_length + + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + chosen_length, _ = valid_length.split(batch_size, dim=0) + + if self.loss_type in ["ipo", "orpo", "simpo"]: + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps + else: + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length \ No newline at end of file diff --git a/kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-235b-sft-gpu-2-amx.yaml b/kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-235b-sft-gpu-2-amx.yaml new file mode 100644 index 00000000..cafdd339 --- /dev/null +++ b/kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-235b-sft-gpu-2-amx.yaml @@ -0,0 +1,157 @@ +- match: + name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\." + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([56789][0-9])\\." + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\." # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearTorch" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\.([56789][0-9])\\." 
# regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + generate_op: "KLinearTorch" + prefill_op: "KLinearTorch" + +#- match: +# name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\.(?!mlp.gate).*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda:0" +# prefill_device: "cuda:0" +# generate_op: "KLinearTorch" +# prefill_op: "KLinearTorch" +#- match: +# name: "^model\\.layers\\.([56789][0-9])\\.(?!mlp.gate).*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda:1" +# prefill_device: "cuda:1" +# generate_op: "KLinearTorch" +# prefill_op: "KLinearTorch" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\.mlp$" + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([56789][0-9])\\.mlp$" + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:0" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KSFTExpertsCPU" + out_device: "cuda:0" + backend: "AMXInt8" # or "AMXInt8" or "AMXBF16" or "llamafile" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.([56789][0-9])\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda:1" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KSFTExpertsCPU" + out_device: "cuda:1" + backend: "AMXInt8" # or "AMXInt8" or "AMXBF16" or "llamafile" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\.(0|[1-9]|[1234][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KQwen3MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "^model\\.layers\\.([56789][0-9])\\.self_attn$" + replace: + class: ktransformers.operators.attention.KQwen3MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen2MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 + transfer_map: + 50: "cuda:1" + +- match: + name: "^lm_head" + class: torch.nn.Linear + replace: + class: 
ktransformers.operators.linear.KTransformersLinear + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + generate_op: "KLinearTorch" + prefill_op: "KLinearTorch" + +- match: + name: "(^model\\.layers\\.(0|[1-9]|[1234][0-9])\\.)" + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + + +- match: + name: "(^model\\.layers\\.([56789][0-9])\\.)|(model.norm)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" diff --git a/kt-sft/ktransformers/util/trainer_utils.py b/kt-sft/ktransformers/util/trainer_utils.py index 7a72bc7c..a5902ef8 100644 --- a/kt-sft/ktransformers/util/trainer_utils.py +++ b/kt-sft/ktransformers/util/trainer_utils.py @@ -1,6 +1,9 @@ +from collections.abc import Mapping +from typing import Union, Optional from accelerate import Accelerator import torch.nn as nn +import torch class KAccelerator(Accelerator): def __init__(self, *args, **kwargs): @@ -17,4 +20,69 @@ def prepare(self, *args, **kwargs): prepped.append(self.prepare_model(obj, **kwargs)) else: prepped.append(super().prepare(obj, **kwargs)) - return tuple(prepped) if len(prepped) > 1 else prepped[0] \ No newline at end of file + return tuple(prepped) if len(prepped) > 1 else prepped[0] + + +def nested_detach( + tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]], + clone: bool = False, +): + r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors).""" + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_detach(t, clone=clone) for t in tensors) + elif isinstance(tensors, Mapping): + return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()}) + + if isinstance(tensors, torch.Tensor): + if clone: + return tensors.detach().clone() + else: + return tensors.detach() + else: + return tensors + +def get_batch_logps( + logits: "torch.Tensor", + labels: "torch.Tensor", + label_pad_token_id: int = -100, + ld_alpha: Optional[float] = None, +) -> tuple["torch.Tensor", "torch.Tensor"]: + r"""Compute the log probabilities of the given labels under the given logits. + + Returns: + logps: A tensor of shape (batch_size,) containing the sum of log probabilities. + valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens. 
+
+    """
+    if logits.shape[:-1] != labels.shape:
+        raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
+
+    labels = labels[:, 1:].clone()
+    logits = logits[:, :-1, :]
+    loss_mask = labels != label_pad_token_id
+    labels[labels == label_pad_token_id] = 0  # dummy token
+    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+    valid_length = loss_mask.sum(-1)
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length
\ No newline at end of file
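
The sequence-level log-probabilities that feed the preference losses come from the new get_batch_logps helper above: labels are shifted right by one position, padded positions (label_pad_token_id == -100) are masked out, per-token log-probabilities are gathered from log_softmax(logits), and the masked sum plus the valid-token count are returned. Below is a minimal, self-contained sanity check of that shift-mask-gather logic; the helper name batch_logps and the tensor sizes are made up for illustration, and the ld_alpha branch is omitted.

    import torch

    def batch_logps(logits, labels, pad_id=-100):
        # Same shift -> mask -> gather scheme as get_batch_logps (ld_alpha path omitted).
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        mask = labels != pad_id
        labels[labels == pad_id] = 0  # dummy index so gather stays in range
        per_token = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        return (per_token * mask).sum(-1), mask.sum(-1)

    logits = torch.randn(4, 8, 32)      # (2 * batch, seq_len, vocab): chosen and rejected stacked
    labels = torch.randint(0, 32, (4, 8))
    labels[:, :3] = -100                # prompt tokens are masked, as in the collated DPO batch
    logps, valid_len = batch_logps(logits, labels)
    assert logps.shape == (4,) and valid_len.shape == (4,)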
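
nested_detach is the other helper added to trainer_utils.py: it walks lists, tuples and dicts recursively and detaches (optionally clones) every tensor it finds, which is what concatenated_forward relies on when use_ref_model is set. A small illustration of the recursion and of the clone=True behaviour, with made-up tensors and dict keys:

    import torch
    from ktransformers.util.trainer_utils import nested_detach

    batch = {"input_ids": torch.randint(0, 10, (2, 4)),
             "extras": [torch.randn(2, 3, requires_grad=True)]}
    out = nested_detach(batch, clone=True)

    assert isinstance(out, dict) and isinstance(out["extras"], list)      # container types preserved
    assert not out["extras"][0].requires_grad                             # tensors are detached
    assert out["extras"][0].data_ptr() != batch["extras"][0].data_ptr()   # clone=True copies storage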
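
For context on why concatenated_forward divides all_logps by valid_length when loss_type is ipo, orpo or simpo: those objectives are defined on length-averaged log-probabilities, whereas vanilla DPO compares summed log-probabilities against a reference model (ORPO additionally couples an odds-ratio term with the SFT loss). The snippet below is a compact reference implementation of the standard DPO and SimPO formulas, not code taken from this repository, and the hyperparameter values are only illustrative.

    import torch
    import torch.nn.functional as F

    def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
        # -log sigmoid(beta * [(pi_c - ref_c) - (pi_r - ref_r)]) on summed log-probs
        margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
        return -F.logsigmoid(beta * margin).mean()

    def simpo_loss(avg_chosen, avg_rejected, beta=2.0, gamma=0.5):
        # reference-free: length-averaged log-probs with a target reward margin gamma
        return -F.logsigmoid(beta * (avg_chosen - avg_rejected) - gamma).mean()

    chosen, rejected = torch.tensor([-20.0]), torch.tensor([-24.0])
    print(dpo_loss(chosen, rejected, torch.tensor([-21.0]), torch.tensor([-23.0])))
    print(simpo_loss(torch.tensor([-1.1]), torch.tensor([-1.6])))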
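
KAccelerator, now shared from ktransformers/util/trainer_utils.py instead of living inside sft/lora.py, exists only to keep Accelerate from re-placing modules that the optimize rules have already distributed across CPU and GPUs: device_placement defaults to False and prepare() hands nn.Module objects back untouched while other objects still go through the normal Accelerate path. A quick pass-through check with a toy module; this assumes a plain environment where Accelerator can initialize on CPU.

    import torch.nn as nn
    from ktransformers.util.trainer_utils import KAccelerator

    acc = KAccelerator()               # device_placement is forced to False in __init__
    toy = nn.Linear(8, 8)              # stand-in for a model already placed by the YAML rules
    assert acc.prepare(toy) is toy     # modules pass straight through: no wrapping, no .to(device)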
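
The new Qwen3Moe-235b-sft-gpu-2-amx.yaml splits the decoder stack between two GPUs purely by layer-index regex: layers 0-49 get cuda:0 operators and layers 50-99 get cuda:1 operators, which lines up with the transfer_map entry that hands hidden states to cuda:1 at layer 50 (experts still run on CPU through KSFTExpertsCPU with the AMX backend). A quick, illustrative check that the two name patterns really partition the layer indices:

    import re

    low  = re.compile(r"^model\.layers\.(0|[1-9]|[1234][0-9])\.")   # routed to cuda:0
    high = re.compile(r"^model\.layers\.([56789][0-9])\.")          # routed to cuda:1

    assert all(low.match(f"model.layers.{i}.self_attn") for i in range(0, 50))
    assert all(high.match(f"model.layers.{i}.self_attn") for i in range(50, 100))
    assert not any(high.match(f"model.layers.{i}.self_attn") for i in range(0, 50))
    assert not any(low.match(f"model.layers.{i}.self_attn") for i in range(50, 100))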