
Commit c70ffe0

fix: Deprecation Warnings in AutoCast API (#113)
* fix: warning of torch.amp.custom_bwd
* fix: warning of torch.amp.custom_bwd
* fix: warning of torch.amp.custom_bwd
* fix: warning of torch.amp.custom_bwd
* fix: warning of torch.amp.custom_bwd
* fix: fmt, lint
* fix: warning of torch.amp.custom_bwd

Signed-off-by: Abhishek <[email protected]>
1 parent e7a0e2f commit c70ffe0

4 files changed, +16 / -18 lines changed

4 files changed

+16
-18
lines changed

plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/nn_modules/triton_utils/dequant.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@
 import itertools

 # Third Party
-from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.amp import custom_bwd, custom_fwd
 import torch
 import triton
 import triton.language as tl
@@ -140,15 +140,15 @@ def quant_matmul_248(

 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
         output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
         ctx.save_for_backward(qweight, scales, qzeros, g_idx)
         ctx.bits, ctx.maxq = bits, maxq
         return output

     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
         qweight, scales, qzeros, g_idx = ctx.saved_tensors
         bits, maxq = ctx.bits, ctx.maxq
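For context, the change above adopts the parenthesized, device-aware form of the AMP custom-autograd decorators. Below is a minimal, self-contained sketch of that pattern, assuming PyTorch >= 2.4 where torch.amp.custom_fwd / custom_bwd accept a device_type argument; the ScaleByTwo Function is a hypothetical toy, not code from this repository.

# Minimal sketch of the decorator style this commit migrates to.
# Assumes PyTorch >= 2.4; ScaleByTwo is a toy example, not repo code.
import torch
from torch.amp import custom_bwd, custom_fwd


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # replaces the deprecated bare @torch.cuda.amp.custom_fwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")  # keeps backward under the same autocast state as forward
    def backward(ctx, grad_output):
        return grad_output * 2

Calling ScaleByTwo.apply(x) inside a torch.autocast("cuda") region should behave exactly as with the old decorators; only the spelling of the decorator changes.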

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -23,8 +23,8 @@
 import torch

 # Local
-from .utils import lora_adapters_switch_ddp_from_fsdp
 from .models.utils import filter_mp_rules
+from .utils import lora_adapters_switch_ddp_from_fsdp


 # consider rewriting register_foak_model_patch_rules into something

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py

Lines changed: 6 additions & 6 deletions
@@ -57,7 +57,7 @@ class LoRA_MLP(torch.autograd.Function):
     Don't forget to see our blog post for more details!
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
                 upW, upW_quant, up_bias, upA, upB, upS,
@@ -104,7 +104,7 @@ def forward(ctx, X : torch.Tensor,


     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
             _backward_function = ctx.custom_saved_tensors
@@ -251,7 +251,7 @@ class LoRA_QKV(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 QW, QW_quant, Q_bias, QA, QB, QS,
                 KW, KW_quant, K_bias, KA, KB, KS,
@@ -294,7 +294,7 @@ def forward(ctx, X : torch.Tensor,
         pass

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
             ctx.custom_saved_tensors
@@ -404,7 +404,7 @@ class LoRA_W(torch.autograd.Function):
     dC/dBv = A.T @ X.T @ D(Wv)
     """
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, X : torch.Tensor,
                 W, W_quant, bias, A, B, S, dropout_O):
         dtype = X.dtype
@@ -423,7 +423,7 @@ def forward(ctx, X : torch.Tensor,
         pass

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY : torch.Tensor):
         W, W_quant, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
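If this code ever needs to keep supporting PyTorch releases that predate torch.amp.custom_fwd / custom_bwd, a guarded import along the following lines would work. The shim below is purely illustrative and is not part of this commit; the custom_fwd_cuda / custom_bwd_cuda names are made up for the example.

# Illustrative compatibility shim, not part of this commit.
# Falls back to the old torch.cuda.amp decorators when the torch.amp
# versions (which take an explicit device_type) are unavailable.
try:
    from torch.amp import custom_bwd, custom_fwd  # PyTorch >= 2.4

    custom_fwd_cuda = custom_fwd(device_type="cuda")
    custom_bwd_cuda = custom_bwd(device_type="cuda")
except ImportError:
    from torch.cuda.amp import custom_bwd, custom_fwd  # older PyTorch

    custom_fwd_cuda = custom_fwd
    custom_bwd_cuda = custom_bwd

# Usage: decorate forward/backward with @custom_fwd_cuda / @custom_bwd_cuda
# instead of referencing either namespace directly.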

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py

Lines changed: 6 additions & 8 deletions
@@ -4,13 +4,11 @@

 # with modifications from The IBM Tuning Team

-import math
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Optional

 import torch
-from torch.cuda.amp import custom_bwd, custom_fwd

 from .triton.kernels import dequant248
 from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel
@@ -213,7 +211,7 @@ class LoRA_MLP(torch.autograd.Function):
     """

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -309,7 +307,7 @@ def forward(
         return i

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         (
             gate_qweight,
@@ -497,7 +495,7 @@ class LoRA_QKV(torch.autograd.Function):
     """

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -591,7 +589,7 @@ def forward(
         return Q, K, V

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dQ, dK, dV):
         (
             Q_qweight,
@@ -770,7 +768,7 @@ class LoRA_W(torch.autograd.Function):
     """

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(
         ctx,
         X: torch.Tensor,
@@ -807,7 +805,7 @@ def forward(
         return XW

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, dY: torch.Tensor):
         O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors
         A, B, X, OX = ctx.saved_tensors
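A quick way to confirm the fix in a regression test is to decorate a dummy function with the new-style decorator and assert that no deprecation warning is emitted. The snippet below is illustrative only and deliberately matches on the word "deprecated" rather than asserting a specific warning class.

# Illustrative check that the new-style decorator emits no deprecation warning.
# Runs without a GPU, since only decoration (not execution) is exercised.
import warnings

from torch.amp import custom_fwd

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    @custom_fwd(device_type="cuda")
    def _dummy_forward(ctx, x):
        return x

assert not any("deprecated" in str(w.message).lower() for w in caught), caught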
