diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/nn_modules/triton_utils/dequant.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/nn_modules/triton_utils/dequant.py
index fc601bd6..048dc78a 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/nn_modules/triton_utils/dequant.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/nn_modules/triton_utils/dequant.py
@@ -17,7 +17,7 @@
 import itertools
 
 # Third Party
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
 import torch
 import triton
 import triton.language as tl
@@ -140,7 +140,7 @@ def quant_matmul_248(
 
 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
@@ -148,7 +148,7 @@ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 0bf35fbb..ea6b3e78 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -23,8 +23,8 @@
 import torch
 
 # Local
-from .utils import lora_adapters_switch_ddp_from_fsdp
 from .models.utils import filter_mp_rules
+from .utils import lora_adapters_switch_ddp_from_fsdp
 
 
 # consider rewriting register_foak_model_patch_rules into something
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
index 55e35ca6..6eee0353 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
@@ -57,7 +57,7 @@ class LoRA_MLP(torch.autograd.Function):
     Don't forget to see our blog post for more details!
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
                 upW, upW_quant, up_bias, upA, upB, upS,
@@ -104,7 +104,7 @@ def forward(ctx, X : torch.Tensor,
 
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
             _backward_function = ctx.custom_saved_tensors
@@ -251,7 +251,7 @@ class LoRA_QKV(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 QW, QW_quant, Q_bias, QA, QB, QS,
                 KW, KW_quant, K_bias, KA, KB, KS,
@@ -294,7 +294,7 @@ def forward(ctx, X : torch.Tensor,
     pass
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
             ctx.custom_saved_tensors
@@ -404,7 +404,7 @@ class LoRA_W(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 W, W_quant, bias, A, B, S, dropout_O):
         dtype = X.dtype
@@ -423,7 +423,7 @@ def forward(ctx, X : torch.Tensor,
     pass
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         W, W_quant, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
index 31aa5d5e..633d8773 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
@@ -4,13 +4,11 @@
 # with modifications from The IBM Tuning Team
 
-import math
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Optional
 
 import torch
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from .triton.kernels import dequant248
 from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel
 
@@ -213,7 +211,7 @@ class LoRA_MLP(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -309,7 +307,7 @@ def forward(
         return i
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         (
             gate_qweight,
@@ -497,7 +495,7 @@ class LoRA_QKV(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -591,7 +589,7 @@ def forward(
         return Q, K, V
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         (
             Q_qweight,
@@ -770,7 +768,7 @@ class LoRA_W(torch.autograd.Function):
     """
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -807,7 +805,7 @@ def forward(
         return XW
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
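
Note on the migration above: `torch.cuda.amp.custom_fwd`/`custom_bwd` are deprecated in recent PyTorch releases in favour of `torch.amp.custom_fwd`/`custom_bwd`, which take an explicit `device_type` keyword. A minimal sketch of the new decorator usage follows; the `Square` function is a hypothetical illustration (not code from this patch) and assumes PyTorch >= 2.4, where the `torch.amp` variants accept `device_type`:

import torch
from torch.amp import custom_bwd, custom_fwd


class Square(torch.autograd.Function):
    """Toy autograd.Function showing the updated decorator form."""

    @staticmethod
    @custom_fwd(device_type="cuda")  # was: @torch.cuda.amp.custom_fwd
    def forward(ctx, X: torch.Tensor) -> torch.Tensor:
        # Save the input for the backward pass, mirroring the pattern
        # used by the patched LoRA / quantized-linear functions.
        ctx.save_for_backward(X)
        return X * X

    @staticmethod
    @custom_bwd(device_type="cuda")  # was: @torch.cuda.amp.custom_bwd
    def backward(ctx, dY: torch.Tensor) -> torch.Tensor:
        (X,) = ctx.saved_tensors
        return 2.0 * X * dY  # d(X*X)/dX = 2X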