
@Deep-unlearning

What does this PR do?

Adds a usage example for the Granite Speech models.

@eustlb

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@eustlb (Contributor) left a comment


Thanks, nice initiative! 🤗
Nits, but we can merge after.

Comment on lines +64 to +65:

```python
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
wav = torch.tensor(dataset[0]["audio"]["array"]).unsqueeze(0)  # add batch dimension
```

Starting from datasets 4.0, this should directly use `audio.get_all_samples()` etc., see the docs.
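A minimal sketch of that pattern, assuming the datasets >= 4.0 behavior where the `audio` column yields a torchcodec `AudioDecoder`:

```python
# Sketch only: assumes datasets>=4.0, where audio decoding goes through
# torchcodec instead of returning {"array": ..., "sampling_rate": ...}.
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
samples = dataset[0]["audio"].get_all_samples()  # AudioSamples from the AudioDecoder
wav = samples.data                               # torch tensor, (num_channels, num_samples)
sampling_rate = samples.sample_rate
```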

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model_name = "ibm-granite/granite-speech-3.3-8b"
```

Nit, but let's rather use `model_id`.

```python
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from datasets import load_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"
```

Let's rather use `device_map="auto"` in `from_pretrained`.
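A minimal sketch of that suggestion (the `model_id` name follows the earlier nit):

```python
# Sketch only: device_map="auto" lets accelerate place the weights,
# replacing the manual cuda/cpu check (requires accelerate installed).
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    device_map="auto",
)
```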

```python
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)
```

`torch_dtype` is deprecated! Let's use `dtype`. Also `dtype="auto"` works here, since `"torch_dtype": "bfloat16"` is set in config.json.
