Skip to content

Commit e0a36fc

Browse files
authored
Fix CutSampler initialization for newer PyTorch versions (#1543)
* Fix CutSampler initialization for newer PyTorch versions * Update unit tests for newer python and pytorch versions * Unfreeze some test package versions * Remove torchscriptability checks for feature extractors
1 parent 434e935 commit e0a36fc

File tree

5 files changed

+13
-63
lines changed

5 files changed

+13
-63
lines changed

.github/workflows/unit_tests.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@ jobs:
1616
strategy:
1717
matrix:
1818
include:
19-
- python-version: "3.8" # note: no torchaudio
20-
torch-install-cmd: "pip install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu"
21-
extra_deps: ""
22-
- python-version: "3.9"
23-
torch-install-cmd: "pip install torch==2.4 torchaudio==2.4 --extra-index-url https://download.pytorch.org/whl/cpu"
24-
extra_deps: "multi-storage-client==0.16.0"
2519
- python-version: "3.10" # note: no torchaudio
2620
torch-install-cmd: "pip install torch==2.5 --extra-index-url https://download.pytorch.org/whl/cpu"
2721
extra_deps: "multi-storage-client==0.16.0"
@@ -31,6 +25,12 @@ jobs:
3125
- python-version: "3.12" # note: no torchaudio
3226
torch-install-cmd: "pip install torch==2.7 --index-url https://download.pytorch.org/whl/cpu"
3327
extra_deps: "multi-storage-client==0.16.0"
28+
- python-version: "3.13" # note: no torchaudio
29+
torch-install-cmd: "pip install torch==2.9 --index-url https://download.pytorch.org/whl/cpu"
30+
extra_deps: "multi-storage-client==0.16.0"
31+
- python-version: "3.14" # note: no torchaudio
32+
torch-install-cmd: "pip install torch==2.9 --index-url https://download.pytorch.org/whl/cpu"
33+
extra_deps: "multi-storage-client==0.16.0"
3434

3535
fail-fast: false
3636

lhotse/dataset/sampling/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ def __init__(
7474
:param rank: Index of distributed node. We will try to infer it by default.
7575
:param seed: Random seed used to consistently shuffle the dataset across different processes.
7676
"""
77-
super().__init__(
78-
data_source=None
79-
) # the "data_source" arg is not used in Sampler...
77+
super().__init__()
8078
self.drop_last = drop_last
8179
self.shuffle = shuffle
8280
self.seed = seed

lhotse/dataset/sampling/bucketing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def _create_buckets_equal_duration_single(
371371
372372
See also: :meth:`.create_buckets_from_duration_percentiles`.
373373
"""
374-
total_duration = np.sum(c.duration for c in cuts)
374+
total_duration = np.sum([c.duration for c in cuts])
375375
bucket_duration = total_duration / num_buckets
376376
# Define the order for adding cuts. We start at the beginning, then go to
377377
# the end, and work our way to the middle. Once in the middle we distribute

setup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,12 @@ def mark_lhotse_version(version: str) -> None:
170170

171171
docs_require = (project_root / "docs" / "requirements.txt").read_text().splitlines()
172172
tests_require = [
173-
"pytest==7.1.3",
174-
"pytest-forked==1.4.0",
175-
"pytest-xdist==2.5.0",
176-
"pytest-cov==4.0.0",
173+
"pytest",
174+
"pytest-forked",
175+
"pytest-xdist",
176+
"pytest-cov",
177177
"flake8==5.0.4",
178-
"coverage==6.5.0",
178+
"coverage",
179179
"hypothesis==6.56.0",
180180
"black==22.3.0",
181181
"isort==5.10.1",

test/features/test_kaldi_layers.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -66,54 +66,6 @@ def test_wav2mfcc(deterministic_rng):
6666
assert y.dtype == torch.float32
6767

6868

69-
def test_wav2win_is_torchscriptable(deterministic_rng):
70-
x = torch.randn(1, 16000, dtype=torch.float32)
71-
t = torch.jit.script(Wav2Win())
72-
y, _ = t(x)
73-
assert y.shape == torch.Size([1, 100, 400])
74-
assert y.dtype == torch.float32
75-
76-
77-
def test_wav2fft_is_torchscriptable(deterministic_rng):
78-
x = torch.randn(1, 16000, dtype=torch.float32)
79-
t = torch.jit.script(Wav2FFT())
80-
y = t(x)
81-
assert y.shape == torch.Size([1, 100, 257])
82-
assert y.dtype == torch.complex64
83-
84-
85-
def test_wav2spec_is_torchscriptable(deterministic_rng):
86-
x = torch.randn(1, 16000, dtype=torch.float32)
87-
t = torch.jit.script(Wav2Spec())
88-
y = t(x)
89-
assert y.shape == torch.Size([1, 100, 257])
90-
assert y.dtype == torch.float32
91-
92-
93-
def test_wav2logspec_is_torchscriptable(deterministic_rng):
94-
x = torch.randn(1, 16000, dtype=torch.float32)
95-
t = torch.jit.script(Wav2LogSpec())
96-
y = t(x)
97-
assert y.shape == torch.Size([1, 100, 257])
98-
assert y.dtype == torch.float32
99-
100-
101-
def test_wav2logfilterbank_is_torchscriptable(deterministic_rng):
102-
x = torch.randn(1, 16000, dtype=torch.float32)
103-
t = torch.jit.script(Wav2LogFilterBank())
104-
y = t(x)
105-
assert y.shape == torch.Size([1, 100, 80])
106-
assert y.dtype == torch.float32
107-
108-
109-
def test_wav2mfcc_is_torchscriptable(deterministic_rng):
110-
x = torch.randn(1, 16000, dtype=torch.float32)
111-
t = torch.jit.script(Wav2MFCC())
112-
y = t(x)
113-
assert y.shape == torch.Size([1, 100, 13])
114-
assert y.dtype == torch.float32
115-
116-
11769
def test_strided_waveform_batch_streaming_snip_edges_false(deterministic_rng):
11870
x = torch.arange(16000).unsqueeze(0)
11971
window_length = 400

0 commit comments

Comments
 (0)