Skip to content

Commit 7c9e7dd

Browse files
committed
Move sklearn import inside kmeans training function
1 parent b349dee commit 7c9e7dd

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/hubert/utils/kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Tuple
99

1010
import torch
11-
from sklearn.cluster import MiniBatchKMeans
1211
from torch import Tensor
1312

1413
from .common_utils import _get_feat_lens_paths, _get_model_path
@@ -102,6 +101,7 @@ def learn_kmeans(
102101
"""
103102
if not km_dir.exists():
104103
km_dir.mkdir()
104+
from sklearn.cluster import MiniBatchKMeans
105105

106106
km_model = MiniBatchKMeans(
107107
n_clusters=n_clusters,

test/torchaudio_unittest/example/hubert/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
import sys
33

44

5-
# sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "hubert"))
5+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "examples", "hubert"))

0 commit comments

Comments
 (0)