File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change 8
8
from packaging .version import Version
9
9
from torch .distributed import ReduceOp
10
10
11
+ from .fp8_config import dynamic_kernel
12
+
11
13
# True when the installed torch is version 2.4.0 or newer. Its negation is
# passed as `disable=` to the torch.compile decorator on `_linear_fp8`, so on
# older torch versions that function is not compiled.
SUPPORT_TORCH_COMPILE = Version (torch .__version__ ) >= Version ("2.4.0" )
12
14
# NOTE(review): presumably the size in bytes of one stored scale value; the
# uses of this constant are outside this view -- confirm before relying on it.
SCALE_BYTES = 4
13
15
try :
@@ -832,11 +834,13 @@ def backward(ctx: Any, out_grad) -> Any:
832
834
return x_grad .reshape (ctx .x_shape ), w_grad , bias_grad
833
835
834
836
835
- @torch .compile (mode = "max-autotune-no-cudagraphs" , disable = not SUPPORT_TORCH_COMPILE , dynamic = False )
837
@torch.compile(
    mode="max-autotune-no-cudagraphs",
    disable=not SUPPORT_TORCH_COMPILE,
    # Decorator arguments are evaluated once, when this module is imported, so
    # later changes to the config flag do not affect this compiled function.
    dynamic=dynamic_kernel,
)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply ``_LinearFp8`` to the given tensors under ``torch.compile``.

    Compilation is disabled when ``SUPPORT_TORCH_COMPILE`` is false, in which
    case this is a plain call to ``_LinearFp8.apply``.

    Args:
        input: Input tensor, forwarded unchanged.
        weight: Weight tensor, forwarded unchanged.
        bias: Optional bias tensor, forwarded unchanged.

    Returns:
        Whatever ``_LinearFp8.apply(input, weight, bias)`` returns.
    """
    return _LinearFp8.apply(input, weight, bias)
838
840
839
841
840
842
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute a linear layer, taking the FP8 path only for 16-aligned inputs.

    The FP8 path (``_linear_fp8``) is used only when both the last dimension
    of ``input`` and the product of all its leading dimensions are multiples
    of 16. Otherwise the call falls back to ``F.linear`` on the original
    tensors. A 1-D ``input`` has a leading-dimension product of 1 and
    therefore always takes the fallback.

    Args:
        input: Input tensor with at least one dimension.
        weight: Weight tensor, forwarded unchanged to the chosen path.
        bias: Optional bias tensor, forwarded unchanged to the chosen path.

    Returns:
        The result of ``_linear_fp8`` or of ``F.linear``.
    """
    # Local import keeps this block self-contained; math.prod works on exact
    # Python ints, whereas np.prod returns a fixed-width NumPy scalar whose
    # integer arithmetic wraps silently on overflow.
    import math

    # NOTE(review): 16 is presumably the alignment the FP8 matmul kernel
    # requires -- confirm against the _LinearFp8 implementation.
    alignment = 16

    in_features = input.shape[-1]
    leading_rows = math.prod(input.shape[:-1])
    if in_features % alignment != 0 or leading_rows % alignment != 0:
        return F.linear(input, weight, bias)

    return _linear_fp8(input, weight, bias)
Original file line number Diff line number Diff line change
1
# Imported by the FP8 module and passed as `dynamic=` to the torch.compile
# decorator on `_linear_fp8`. That decorator reads the value once, at import
# time of the importing module, so reassigning this flag afterwards has no
# effect on the already-decorated function.
+ dynamic_kernel : bool = False
You can’t perform that action at this time.
0 commit comments