Commit 5ddad48

[fp8] add fallback and make compile option configurable (#6092)
1 parent 3b1d7d1 commit 5ddad48

2 files changed: +6 −1 lines changed


colossalai/quantization/fp8.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,8 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:

@@ -832,11 +834,13 @@ def backward(ctx: Any, out_grad) -> Any:
     return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
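The new guard in linear_fp8 is the fallback named in the commit title: when the feature dimension or the collapsed leading dimensions of the input are not multiples of 16 (a typical alignment requirement for FP8 GEMM kernels, though the rationale is not stated in the diff), the call quietly degrades to the ordinary F.linear path instead of the compiled FP8 one. Below is a minimal, self-contained sketch of the same pattern; _linear_fp8 is replaced by a plain matmul stub so the snippet runs anywhere, and is not the actual ColossalAI implementation.

import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional


def _linear_fp8(input, weight, bias=None):
    # Stand-in for ColossalAI's compiled FP8 kernel (_LinearFp8.apply);
    # swapped for a plain matmul so this sketch runs without FP8 support.
    return F.linear(input, weight, bias)


def linear_fp8_with_fallback(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Shapes that are not 16-aligned skip the FP8 path entirely.
    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
        return F.linear(input, weight, bias)
    return _linear_fp8(input, weight, bias)


# A (3, 20) activation is not 16-aligned, so it takes the fallback branch;
# a (32, 64) activation would reach the FP8 kernel instead.
x, w = torch.randn(3, 20), torch.randn(8, 20)
print(linear_fp8_with_fallback(x, w).shape)  # torch.Size([3, 8])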

colossalai/quantization/fp8_config.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+dynamic_kernel: bool = False
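This new module holds the flag that the @torch.compile decorator above reads as its dynamic argument. Because fp8.py imports the value at module import time, changing it only takes effect if that happens before colossalai.quantization.fp8 is first imported. The snippet below is a hedged usage sketch under that assumption; it is not documented by the commit, and any earlier import of the fp8 module (for example from a package __init__) would already have captured the default.

# Opt in to a dynamic-shape torch.compile of the FP8 linear kernel.
import colossalai.quantization.fp8_config as fp8_config

fp8_config.dynamic_kernel = True  # default is False

# Must come after the flag is set, since fp8.py bakes the value into its
# @torch.compile(..., dynamic=dynamic_kernel) decorator when imported.
from colossalai.quantization.fp8 import linear_fp8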
