Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions examples/tutorials/audio_resampling_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
print(torch.__version__)
print(torchaudio.__version__)

if torch.accelerator.is_available():
acc = torch.accelerator.current_accelerator()
device = torch.device(acc)
else:
device = torch.device("cpu")
print(f"Using device: {device}")

######################################################################
# Preparation
# -----------
Expand All @@ -41,7 +48,7 @@ def _get_log_freq(sample_rate, max_sweep_rate, offset):

"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double, device=device)) - offset


def _get_inverse_log_freq(freq, sample_rate, offset):
Expand Down Expand Up @@ -91,6 +98,7 @@ def plot_sweep(
freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2]

figure, axis = plt.subplots(1, 1)
waveform = waveform.cpu() if torch.is_tensor(waveform) else waveform
_, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y)
Expand Down Expand Up @@ -144,7 +152,7 @@ def plot_sweep(
waveform = get_sine_sweep(sample_rate)

plot_sweep(waveform, sample_rate, title="Original Waveform")
Audio(waveform.numpy()[0], rate=sample_rate)
Audio(waveform.cpu().numpy()[0], rate=sample_rate)

######################################################################
#
Expand All @@ -157,11 +165,11 @@ def plot_sweep(
# an explanation of how it happens, and why it looks like a reflection.

resample_rate = 32000
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype).to(device)
resampled_waveform = resampler(waveform)

plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)
Audio(resampled_waveform.cpu().numpy()[0], rate=resample_rate)

######################################################################
# Controling resampling quality with parameters
Expand All @@ -182,13 +190,13 @@ def plot_sweep(
sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6).to(device)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")

######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128).to(device)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")

######################################################################
Expand All @@ -208,13 +216,13 @@ def plot_sweep(
sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99).to(device)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")

######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8).to(device)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")


Expand All @@ -234,13 +242,13 @@ def plot_sweep(
sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann").to(device)
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")

######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser").to(device)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")


Expand All @@ -267,7 +275,7 @@ def plot_sweep(
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
).to(device)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")

######################################################################
Expand All @@ -282,7 +290,7 @@ def plot_sweep(
rolloff=0.85,
resampling_method="sinc_interp_kaiser",
beta=8.555504641634386,
)
).to(device)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

######################################################################
Expand Down
8 changes: 7 additions & 1 deletion examples/tutorials/forced_alignment_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@
print(torchaudio.__version__)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.accelerator.is_available():
acc = torch.accelerator.current_accelerator()
device = torch.device(acc)
backend = torch.distributed.get_default_backend_for_device(device)
else:
device = torch.device("cpu")
backend = "gloo"
print(device)


Expand Down
8 changes: 7 additions & 1 deletion examples/tutorials/speech_recognition_pipeline_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@
print(torchaudio.__version__)

torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.accelerator.is_available():
acc = torch.accelerator.current_accelerator()
device = torch.device(acc)
backend = torch.distributed.get_default_backend_for_device(device)
else:
device = torch.device("cpu")
backend = "gloo"

print(device)

Expand Down