File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed
Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change 88from packaging .version import Version
99from torch .distributed import ReduceOp
1010
11+ from .fp8_config import dynamic_kernel
12+
1113SUPPORT_TORCH_COMPILE = Version (torch .__version__ ) >= Version ("2.4.0" )
1214SCALE_BYTES = 4
1315try :
@@ -832,11 +834,13 @@ def backward(ctx: Any, out_grad) -> Any:
832834 return x_grad .reshape (ctx .x_shape ), w_grad , bias_grad
833835
834836
835- @torch .compile (mode = "max-autotune-no-cudagraphs" , disable = not SUPPORT_TORCH_COMPILE , dynamic = False )
837+ @torch .compile (mode = "max-autotune-no-cudagraphs" , disable = not SUPPORT_TORCH_COMPILE , dynamic = dynamic_kernel )
836838def _linear_fp8 (input : torch .Tensor , weight : torch .Tensor , bias : Optional [torch .Tensor ] = None ) -> torch .Tensor :
837839 return _LinearFp8 .apply (input , weight , bias )
838840
839841
840842def linear_fp8 (input : torch .Tensor , weight : torch .Tensor , bias : Optional [torch .Tensor ] = None ) -> torch .Tensor :
843+ if input .shape [- 1 ] % 16 != 0 or np .prod (input .shape [:- 1 ]) % 16 != 0 :
844+ return F .linear (input , weight , bias )
841845 out = _linear_fp8 (input , weight , bias )
842846 return out
Original file line number Diff line number Diff line change 1+ dynamic_kernel : bool = False
You can’t perform that action at this time.
0 commit comments