diff --git a/examples/tutorials/audio_resampling_tutorial.py b/examples/tutorials/audio_resampling_tutorial.py index adca5073c9..4071599d7f 100644 --- a/examples/tutorials/audio_resampling_tutorial.py +++ b/examples/tutorials/audio_resampling_tutorial.py @@ -16,6 +16,13 @@ print(torch.__version__) print(torchaudio.__version__) +if torch.accelerator.is_available(): + acc = torch.accelerator.current_accelerator() + device = torch.device(acc) +else: + device = torch.device("cpu") +print(f"Using device: {device}") + ###################################################################### # Preparation # ----------- @@ -41,7 +48,7 @@ def _get_log_freq(sample_rate, max_sweep_rate, offset): """ start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2) - return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset + return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double, device=device)) - offset def _get_inverse_log_freq(freq, sample_rate, offset): @@ -91,6 +98,7 @@ def plot_sweep( freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2] figure, axis = plt.subplots(1, 1) + waveform = waveform.cpu() if torch.is_tensor(waveform) else waveform _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate) plt.xticks(time, freq_x) plt.yticks(freq_y, freq_y) @@ -144,7 +152,7 @@ def plot_sweep( waveform = get_sine_sweep(sample_rate) plot_sweep(waveform, sample_rate, title="Original Waveform") -Audio(waveform.numpy()[0], rate=sample_rate) +Audio(waveform.cpu().numpy()[0], rate=sample_rate) ###################################################################### # @@ -157,11 +165,11 @@ def plot_sweep( # an explanation of how it happens, and why it looks like a reflection. resample_rate = 32000 -resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype) +resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype).to(device) resampled_waveform = resampler(waveform) plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform") -Audio(resampled_waveform.numpy()[0], rate=resample_rate) +Audio(resampled_waveform.cpu().numpy()[0], rate=resample_rate) ###################################################################### # Controling resampling quality with parameters @@ -182,13 +190,13 @@ def plot_sweep( sample_rate = 48000 resample_rate = 32000 -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6) +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6).to(device) plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6") ###################################################################### # -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128) +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128).to(device) plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128") ###################################################################### @@ -208,13 +216,13 @@ def plot_sweep( sample_rate = 48000 resample_rate = 32000 -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99) +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99).to(device) plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99") ###################################################################### # -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8) +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8).to(device) plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8") @@ -234,13 +242,13 @@ def plot_sweep( sample_rate = 48000 resample_rate = 32000 -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann") +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann").to(device) plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default") ###################################################################### # -resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser") +resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser").to(device) plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") @@ -267,7 +275,7 @@ def plot_sweep( rolloff=0.9475937167399596, resampling_method="sinc_interp_kaiser", beta=14.769656459379492, -) +).to(device) plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)") ###################################################################### @@ -282,7 +290,7 @@ def plot_sweep( rolloff=0.85, resampling_method="sinc_interp_kaiser", beta=8.555504641634386, -) +).to(device) plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)") ###################################################################### diff --git a/examples/tutorials/forced_alignment_tutorial.py b/examples/tutorials/forced_alignment_tutorial.py index 7fa7c86dc3..969af1fde5 100644 --- a/examples/tutorials/forced_alignment_tutorial.py +++ b/examples/tutorials/forced_alignment_tutorial.py @@ -47,7 +47,13 @@ print(torchaudio.__version__) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.accelerator.is_available(): + acc = torch.accelerator.current_accelerator() + device = torch.device(acc) + backend = torch.distributed.get_default_backend_for_device(device) +else: + device = torch.device("cpu") + backend = "gloo" print(device) diff --git a/examples/tutorials/speech_recognition_pipeline_tutorial.py b/examples/tutorials/speech_recognition_pipeline_tutorial.py index 2c8dfc752b..f5654ba7d4 100644 --- a/examples/tutorials/speech_recognition_pipeline_tutorial.py +++ b/examples/tutorials/speech_recognition_pipeline_tutorial.py @@ -42,7 +42,13 @@ print(torchaudio.__version__) torch.random.manual_seed(0) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.accelerator.is_available(): + acc = torch.accelerator.current_accelerator() + device = torch.device(acc) + backend = torch.distributed.get_default_backend_for_device(device) +else: + device = torch.device("cpu") + backend = "gloo" print(device)