
Commit e7ff7bf

committed
.
1 parent 332d98a commit e7ff7bf

File tree

auto_fp8/__init__.py
auto_fp8/modeling.py
auto_fp8/quantize.py
examples/quantize.py

4 files changed: +11 -13 lines changed

auto_fp8/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -1,2 +1,7 @@
 from .modeling import AutoFP8ForCausalLM
 from .config import BaseQuantizeConfig
+
+__all__ = [
+    "AutoFP8ForCausalLM",
+    "BaseQuantizeConfig",
+]
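With __all__ in place, the package's public surface is exactly these two names. A minimal usage sketch of that surface; the model identifier and the BaseQuantizeConfig arguments below are assumptions for illustration, not taken from this commit:

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

# Hypothetical example: the config arguments and model name are assumptions.
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
model = AutoFP8ForCausalLM.from_pretrained("facebook/opt-125m", quantize_config=quantize_config)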

auto_fp8/modeling.py

Lines changed: 1 addition & 8 deletions

@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from transformers import AutoModelForCausalLM, PreTrainedModel
 from auto_fp8.quantize import (
     quantize_weights,
     quantize_activations,
@@ -14,8 +14,6 @@ def __init__(
         model: PreTrainedModel,
         quantize_config: BaseQuantizeConfig,
     ):
-        # super().__init__()
-
         self.model = model
         self.model_type = self.model.config.model_type
         self.quantize_config = quantize_config
@@ -30,11 +28,6 @@ def from_pretrained(
     ):
         """Load the un-quantized pretrained model"""

-        # if not torch.cuda.is_available():
-        #     raise EnvironmentError(
-        #         "Load pretrained model to do quantization requires CUDA available."
-        #     )
-
         def skip(*args, **kwargs):
             pass
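The skip helper that survives this hunk is the usual trick for speeding up from_pretrained: default weight initializers are stubbed out because the loaded checkpoint overwrites the weights anyway. A minimal sketch of that pattern, assuming it is applied by monkey-patching torch.nn.init (which specific initializers get patched is an assumption, not shown in this diff):

import torch

def skip(*args, **kwargs):
    pass

# Assumed pattern: turn the common initializers into no-ops before calling
# AutoModelForCausalLM.from_pretrained(), so layers are not randomly
# initialized only to be overwritten by the checkpoint weights.
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip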

auto_fp8/quantize.py

Lines changed: 5 additions & 4 deletions

@@ -2,10 +2,9 @@
 import re
 from typing import Tuple
 import torch
-import torch.functional as F
 import transformers
 import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer


 # HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -59,8 +58,10 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:


 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
-    cuda_compute_capability = torch.cuda.get_device_capability()
-    if cuda_compute_capability >= (9, 0):
+    native_fp8_support = (
+        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+    )
+    if native_fp8_support:
         output, _ = torch._scaled_mm(
             A,
             B.t(),
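The reworked guard keeps fp8_gemm from probing the GPU on CPU-only hosts: torch.cuda.get_device_capability() fails when CUDA is unavailable, and torch._scaled_mm only has native FP8 kernels on compute capability 9.0+ (Hopper). A rough sketch of the resulting control flow; the _scaled_mm keyword arguments and the dequantize-then-matmul fallback are assumptions based on the older tuple-returning _scaled_mm API, not this file's exact code:

import torch

def fp8_gemm_sketch(A, A_scale, B, B_scale, bias, out_dtype):
    # Native FP8 matmul needs a CUDA device with compute capability >= 9.0;
    # checking is_available() first keeps CPU-only machines from erroring here.
    native_fp8_support = (
        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    )
    if native_fp8_support:
        # Older torch._scaled_mm returns (output, amax); argument names assumed.
        output, _ = torch._scaled_mm(
            A,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
    else:
        # Assumed fallback: dequantize to out_dtype and run a regular matmul.
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale,
            bias=bias,
        )
    return output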

examples/quantize.py

Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 from typing import Tuple

 import torch
-import torch.functional as F
 import transformers
 import tqdm
 from datasets import load_dataset
