@@ -32,10 +32,10 @@ def __init__(
32
32
self .f0_max = 1100
33
33
self .f0_mel_min = 1127 * np .log (1 + self .f0_min / 700 )
34
34
self .f0_mel_max = 1127 * np .log (1 + self .f0_max / 700 )
35
- if index_rate != 0 :
35
+ if index_rate != 0 :
36
36
self .index = faiss .read_index (index_path )
37
37
self .big_npy = np .load (npy_path )
38
- print (' index search enabled' )
38
+ print (" index search enabled" )
39
39
self .index_rate = index_rate
40
40
model_path = hubert_path
41
41
print ("load model(s) from {}" .format (model_path ))
@@ -111,11 +111,7 @@ def infer(self, feats: torch.Tensor) -> np.ndarray:
111
111
feats = self .model .final_proj (logits [0 ])
112
112
113
113
####索引优化
114
- if (
115
- hasattr (self ,'index' )
116
- and hasattr (self ,'big_npy' )
117
- and self .index_rate != 0
118
- ):
114
+ if hasattr (self , "index" ) and hasattr (self , "big_npy" ) and self .index_rate != 0 :
119
115
npy = feats [0 ].cpu ().numpy ().astype ("float32" )
120
116
_ , I = self .index .search (npy , 1 )
121
117
npy = self .big_npy [I .squeeze ()].astype ("float16" )
@@ -124,7 +120,7 @@ def infer(self, feats: torch.Tensor) -> np.ndarray:
124
120
+ (1 - self .index_rate ) * feats
125
121
)
126
122
else :
127
- print (' index search FAIL or disabled' )
123
+ print (" index search FAIL or disabled" )
128
124
129
125
feats = F .interpolate (feats .permute (0 , 2 , 1 ), scale_factor = 2 ).permute (0 , 2 , 1 )
130
126
torch .cuda .synchronize ()
0 commit comments