Skip to content

Commit e27f4dd

Browse files
authored
Merge pull request #79 from kadirnar/update-test
💬 Add new parameters for hqq optimization method
2 parents 317bdc9 + e5ceca7 commit e27f4dd

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

scripts/runpod.sh

100644 → 100755
File mode changed.

whisperplus/pipelines/whisper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import logging
22

33
import torch
4+
from hqq.core.quantize import HQQBackend, HQQLinear
45
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
56

7+
HQQLinear.set_backend(HQQBackend.PYTORCH) # Pytorch backend
8+
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) # Compiled Pytorch via dynamo
9+
HQQLinear.set_backend(HQQBackend.ATEN) # C++ Aten/CUDA backend (set automatically by default if available)
10+
611
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
712

813

@@ -25,7 +30,7 @@ def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_con
2530
low_cpu_mem_usage=True,
2631
use_safetensors=True,
2732
attn_implementation="flash_attention_2",
28-
torch_dtype=torch.float16,
33+
torch_dtype=torch.bfloat16,
2934
device_map='auto',
3035
max_memory={0: "24GiB"})
3136
logging.info("Model loaded successfully.")

whisperplus/test.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
import time
2-
31
import torch
2+
from hqq.utils.patching import prepare_for_inference
43
from pipelines.whisper import SpeechToTextPipeline
54
from transformers import BitsAndBytesConfig, HqqConfig
65
from utils.download_utils import download_and_convert_to_mp3
76

8-
url = "https://www.youtube.com/watch?v=di3rHkEZuUw"
7+
url = "https://www.youtube.com/watch?v=BpN4hEAvDBg"
98
audio_path = download_and_convert_to_mp3(url)
109

1110
hqq_config = HqqConfig(
12-
nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0) # axis=0 is used by default
11+
nbits=1,
12+
group_size=64,
13+
quant_zero=False,
14+
quant_scale=False,
15+
axis=0,
16+
offload_meta=False,
17+
) # axis=0 is used by default
1318

1419
bnb_config = BitsAndBytesConfig(
1520
load_in_4bit=True,
@@ -18,14 +23,14 @@
1823
bnb_4bit_use_double_quant=True,
1924
)
2025
model = SpeechToTextPipeline(
21-
model_id="distil-whisper/distil-large-v3", quant_config=bnb_config) # or bnb_config
26+
model_id="distil-whisper/distil-large-v3", quant_config=hqq_config) # or bnb_config
2227

2328
start_event = torch.cuda.Event(enable_timing=True)
2429
end_event = torch.cuda.Event(enable_timing=True)
2530

2631
start_event.record()
2732
transcript = model(
28-
audio_path="testv0.mp3",
33+
audio_path=audio_path,
2934
chunk_length_s=30,
3035
stride_length_s=5,
3136
max_new_tokens=128,
@@ -36,4 +41,5 @@
3641

3742
torch.cuda.synchronize()
3843
elapsed_time_ms = start_event.elapsed_time(end_event)
39-
print(f"Execution time: {elapsed_time_ms}ms")
44+
seconds = elapsed_time_ms / 1000
45+
print(f"Elapsed time: {seconds} seconds")

0 commit comments

Comments (0)