@@ -17,7 +17,7 @@
import itertools

# Third Party
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
import torch
import triton
import triton.language as tl
@@ -140,15 +140,15 @@ def quant_matmul_248(

class QuantLinearFunction(torch.autograd.Function):
@staticmethod
-@custom_fwd
+@custom_fwd(device_type="cuda")
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output

@staticmethod
-@custom_bwd
+@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
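
For context, a minimal self-contained sketch of the decorator pattern this hunk migrates to; the ScaleFunction class and its body are illustrative only and not taken from this PR. torch.amp.custom_fwd / torch.amp.custom_bwd take an explicit device_type argument, unlike the deprecated torch.cuda.amp variants.

# Illustrative only: a toy autograd.Function using the new torch.amp decorators.
import torch
from torch.amp import custom_bwd, custom_fwd

class ScaleFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, scale):
        # Save the scalar needed for the backward pass.
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # Gradient w.r.t. x; scale is a plain float, so it gets no gradient.
        return grad_output * ctx.scale, None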
@@ -23,8 +23,8 @@
import torch

# Local
-from .utils import lora_adapters_switch_ddp_from_fsdp
from .models.utils import filter_mp_rules
+from .utils import lora_adapters_switch_ddp_from_fsdp


# consider rewriting register_foak_model_patch_rules into something
@@ -57,7 +57,7 @@ class LoRA_MLP(torch.autograd.Function):
Don't forget to see our blog post for more details!
"""
@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx, X : torch.Tensor,
gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
upW, upW_quant, up_bias, upA, upB, upS,
@@ -104,7 +104,7 @@ def forward(ctx, X : torch.Tensor,


@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dY : torch.Tensor):
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
_backward_function = ctx.custom_saved_tensors
@@ -251,7 +251,7 @@ class LoRA_QKV(torch.autograd.Function):
dC/dBv = A.T @ X.T @ D(Wv)
"""
@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx, X : torch.Tensor,
QW, QW_quant, Q_bias, QA, QB, QS,
KW, KW_quant, K_bias, KA, KB, KS,
@@ -294,7 +294,7 @@ def forward(ctx, X : torch.Tensor,
pass

@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dQ, dK, dV):
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
ctx.custom_saved_tensors
@@ -404,7 +404,7 @@ class LoRA_W(torch.autograd.Function):
dC/dBv = A.T @ X.T @ D(Wv)
"""
@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx, X : torch.Tensor,
W, W_quant, bias, A, B, S, dropout_O):
dtype = X.dtype
@@ -423,7 +423,7 @@ def forward(ctx, X : torch.Tensor,
pass

@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dY : torch.Tensor):
W, W_quant, S = ctx.custom_saved_tensors
A, B, X, OX = ctx.saved_tensors
@@ -4,13 +4,11 @@

# with modifications from The IBM Tuning Team

-import math
from dataclasses import dataclass
from logging import getLogger
from typing import Optional

import torch
-from torch.cuda.amp import custom_bwd, custom_fwd

from .triton.kernels import dequant248
from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel
@@ -213,7 +211,7 @@ class LoRA_MLP(torch.autograd.Function):
"""

@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(
ctx,
X: torch.Tensor,
@@ -309,7 +307,7 @@ def forward(
return i

@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dY: torch.Tensor):
(
gate_qweight,
@@ -497,7 +495,7 @@ class LoRA_QKV(torch.autograd.Function):
"""

@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(
ctx,
X: torch.Tensor,
@@ -591,7 +589,7 @@ def forward(
return Q, K, V

@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dQ, dK, dV):
(
Q_qweight,
@@ -770,7 +768,7 @@ class LoRA_W(torch.autograd.Function):
"""

@staticmethod
-@torch.cuda.amp.custom_fwd
+@torch.amp.custom_fwd(device_type='cuda')
def forward(
ctx,
X: torch.Tensor,
@@ -807,7 +805,7 @@ def forward(
return XW

@staticmethod
-@torch.cuda.amp.custom_bwd
+@torch.amp.custom_bwd(device_type='cuda')
def backward(ctx, dY: torch.Tensor):
O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors
A, B, X, OX = ctx.saved_tensors
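
For completeness, a hypothetical compatibility shim, not part of this PR and with names chosen here only for illustration, of the kind some projects use to support both decorator locations, since torch.amp.custom_fwd / custom_bwd with device_type exist only in newer PyTorch releases:

# Hypothetical shim (not in this PR): prefer the new torch.amp decorators and
# fall back to the deprecated torch.cuda.amp ones on older PyTorch versions.
import functools
import torch

if hasattr(torch.amp, "custom_fwd"):
    # New API: bind device_type so the result can be used as a bare decorator.
    custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")
else:
    # Old API: already a bare decorator.
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd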