@@ -22,7 +22,9 @@ def __init__(self, mX, sTarget, nResidual, psTarget = [], pnResidual = [], alpha
2222 self ._alpha = alpha
2323 self ._method = method
2424 self ._iterations = 200
25- self ._lr = 2e-3
25+ self ._lr = 3e-3 #2e-3
26+ self ._hetaplus = 1.2
27+ self ._hetaminus = 0.5
2628
2729 def __call__ (self , reverse = False ):
2830
@@ -255,8 +257,8 @@ def phaseSensitive(self):
255257
256258 def optAlpha (self , initloss ):
257259 """
258- A simple gradiend descent method, to find optimum power-spectral density exponents (alpha)
259- for generalized wiener filtering.
260+ A simple gradiend descent method using the RProp algorithm,
261+ for finding optimum power-spectral density exponents (alpha) for generalized wiener filtering.
260262 Args:
261263 sTarget : (2D ndarray) Magnitude Spectrogram of the target component
262264 nResidual: (2D ndarray) Magnitude Spectrogram of the residual component or a list
@@ -273,8 +275,8 @@ def optAlpha(self, initloss):
273275 numElements = len (slist )
274276 slist = np .asarray (slist )
275277
276- alpha = np .array ([1.2 ] * (numElements )) # Initialize an array of alpha values to be found.
277- dloss = np .array ([0. ] * (numElements )) # Initialize an array of loss functions to be used.
278+ alpha = np .array ([1.15 ] * (numElements )) # Initialize an array of alpha values to be found.
279+ dloss = np .array ([0. ] * (numElements )) # Initialize an array of loss functions to be used.
278280 lrs = np .array ([self ._lr ] * (numElements )) # Initialize an array of learning rates to be applied to each source.
279281
280282 # Begin of otpimization
@@ -291,7 +293,7 @@ def optAlpha(self, initloss):
291293
292294 alpha -= (lrs * dloss )
293295
294- # Make sure of un-wanted values
296+ # Make sure the initial alpha are inside reasonable values
295297 alpha = np .clip (alpha , a_min = 0.5 , a_max = 2. )
296298
297299 # Check IS Loss by computing Xhat
@@ -301,16 +303,25 @@ def optAlpha(self, initloss):
301303
302304 isloss .append (self ._IS (Xhat ))
303305 if (iter > 2 ):
306+ # Apply RProp
307+ if (isloss [- 2 ] - isloss [- 1 ] > 0 ):
308+ lrs *= self ._hetaplus
309+
304310 if (isloss [- 2 ] - isloss [- 1 ] < 0 ):
305- print ('Local Minimum Found' )
306- alpha += (lrs * dloss )
307- break
311+ lrs *= self ._hetaminus
312+
313+ if (iter > 4 ):
314+ if (np .abs (isloss [- 2 ] - isloss [- 1 ]) < 1e-4 and np .abs (isloss [- 3 ] - isloss [- 2 ]) < 1e-4 ):
315+ print ('Local Minimum Found' )
316+ print ('Final Loss: ' + str (isloss [- 1 ]) + ' with characteristic exponent(s): ' + str (alpha ))
317+ break
308318
309319 print ('Loss: ' + str (isloss [- 1 ]) + ' with characteristic exponent(s): ' + str (alpha ))
310320
311321 # Evaluate Xhat for the mask update
312322 self ._mask = np .divide ((slist [0 , :, :] ** alpha [0 ] + self ._eps ), (self ._mX ** self ._alpha + self ._eps ))
313- self ._closs = isloss
323+ self ._closs = isloss [- 1 ]
324+ self ._alpha = alpha
314325
315326 def MWF (self ):
316327 """ Multi-channel Wiener filtering as appears in:
0 commit comments