Add granit speech doc #41360
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks, nice initiative! 🤗
nits but we can merge after
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
wav = torch.tensor(dataset[0]["audio"]["array"]).unsqueeze(0)  # add batch dimension
Starting from datasets 4.0 this should directly use audio.get_all_samples() etc., see the doc.
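For reference, a minimal sketch of that pattern, assuming datasets >= 4.0 where the audio column decodes to a torchcodec AudioDecoder (attribute names taken from the datasets 4.0 docs; treat them as assumptions if your version differs):

```python
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# datasets >= 4.0: the "audio" column is a torchcodec AudioDecoder,
# so samples are decoded on demand instead of read from ["array"]
samples = dataset[0]["audio"].get_all_samples()
wav = samples.data                    # torch.Tensor, shape (num_channels, num_samples)
sampling_rate = samples.sample_rate

# for mono audio, wav already has a leading channel dim of 1,
# so the explicit .unsqueeze(0) from the original snippet is not needed
```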
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model_name = "ibm-granite/granite-speech-3.3-8b"
nit, but let's rather use model_id
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from datasets import load_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"
let's rather use device_map="auto" in from_pretrained
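A short sketch of that change (device_map="auto" needs accelerate installed; the rest of the snippet is unchanged):

```python
from transformers import AutoModelForSpeechSeq2Seq

# no manual cuda/cpu check: accelerate places the weights automatically
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "ibm-granite/granite-speech-3.3-8b", device_map="auto"
)
```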
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
torch_dtype is deprecated! let's use dtype. Also dtype="auto" here, since "torch_dtype": "bfloat16" is in config.json.
What does this PR do?
Add a usage example for the Granite Speech models
@eustlb
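For context, a rough end-to-end version of the usage example this PR adds, with the review suggestions applied. The <|audio|> placeholder and the processor(prompt, wav, ...) call follow the pattern on the Granite Speech model card, so treat the exact prompt strings as assumptions rather than the PR's final code:

```python
import torch
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

model_id = "ibm-granite/granite-speech-3.3-8b"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, device_map="auto", dtype="auto"
)

# pre-4.0 datasets style, as quoted in the PR diff
# (see the datasets 4.0 sketch earlier for the newer pattern)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
wav = torch.tensor(dataset[0]["audio"]["array"]).unsqueeze(0)  # add batch dimension

# prompt format assumed from the Granite Speech model card
chat = [{"role": "user", "content": "<|audio|>can you transcribe the speech into a written format?"}]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

inputs = processor(prompt, wav, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)

# strip the prompt tokens and decode only the generated transcription
print(tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```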