Skip to content

Commit 68ac03d

Browse files
committed
Removed numba, made all pyfftw imports conditional
1 parent 8b72960 commit 68ac03d

File tree

7 files changed

+38
-200
lines changed

7 files changed

+38
-200
lines changed

Dockerfile.safe

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,8 @@ RUN mamba install -y python \
7777
h5py \
7878
"tensorflow>=2.4.0" \
7979
pyqtgraph \
80-
pyfftw \
8180
pandas \
82-
versioneer \
83-
numba; sync && \
81+
versioneer; sync && \
8482
chmod -R a+rX /usr/local/miniconda; sync && \
8583
chmod +x /usr/local/miniconda/bin/*; sync && \
8684
conda-build purge-all; sync && \

capcalc/filter.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,27 @@
2323
2424
"""
2525
import sys
26+
import warnings
2627

2728
import matplotlib.pyplot as plt
2829
import numpy as np
29-
import pyfftw
3030

31-
# from numba import jit
31+
with warnings.catch_warnings():
32+
warnings.simplefilter("ignore")
33+
try:
34+
import pyfftw
35+
except ImportError:
36+
pyfftwpresent = False
37+
else:
38+
pyfftwpresent = True
39+
3240
from scipy import fftpack, ndimage, signal
3341
from scipy.signal import savgol_filter
3442

35-
# import warnings
36-
43+
if pyfftwpresent:
44+
fftpack = pyfftw.interfaces.scipy_fftpack
45+
pyfftw.interfaces.cache.enable()
3746

38-
# fftpack = pyfftw.interfaces.scipy_fftpack
39-
# pyfftw.interfaces.cache.enable()
40-
41-
# ---------------------------------------- Global constants -------------------------------------------
42-
donotusenumba = True
4347

4448
# ----------------------------------------- Conditional imports ---------------------------------------
4549
try:
@@ -49,23 +53,6 @@
4953
except ImportError:
5054
memprofilerexists = False
5155

52-
53-
# ----------------------------------------- Conditional jit handling ----------------------------------
54-
def conditionaljit():
55-
def resdec(f):
56-
global donotusenumba
57-
if donotusenumba:
58-
return f
59-
return jit(f, nopython=False)
60-
61-
return resdec
62-
63-
64-
def disablenumba():
65-
global donotusenumba
66-
donotusenumba = True
67-
68-
6956
# --------------------------- Filtering functions -------------------------------------------------
7057
# NB: No automatic padding for precalculated filters
7158

@@ -177,7 +164,8 @@ def ssmooth(xsize, ysize, zsize, sigma, inputdata):
177164

178165

179166
# - butterworth filters
180-
@conditionaljit()
167+
168+
181169
def dolpfiltfilt(Fs, upperpass, inputdata, order, padlen=20, cyclic=False, debug=False):
182170
r"""Performs a bidirectional (zero phase) Butterworth lowpass filter on an input vector
183171
and returns the result. Ends are padded to reduce transients.
@@ -235,7 +223,6 @@ def dolpfiltfilt(Fs, upperpass, inputdata, order, padlen=20, cyclic=False, debug
235223
).astype(np.float64)
236224

237225

238-
@conditionaljit()
239226
def dohpfiltfilt(Fs, lowerpass, inputdata, order, padlen=20, cyclic=False, debug=False):
240227
r"""Performs a bidirectional (zero phase) Butterworth highpass filter on an input vector
241228
and returns the result. Ends are padded to reduce transients.
@@ -292,7 +279,6 @@ def dohpfiltfilt(Fs, lowerpass, inputdata, order, padlen=20, cyclic=False, debug
292279
)
293280

294281

295-
@conditionaljit()
296282
def dobpfiltfilt(Fs, lowerpass, upperpass, inputdata, order, padlen=20, cyclic=False, debug=False):
297283
r"""Performs a bidirectional (zero phase) Butterworth bandpass filter on an input vector
298284
and returns the result. Ends are padded to reduce transients.
@@ -419,7 +405,6 @@ def getlpfftfunc(Fs, upperpass, inputdata, debug=False):
419405
return transferfunc
420406

421407

422-
@conditionaljit()
423408
def dolpfftfilt(Fs, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
424409
r"""Performs an FFT brickwall lowpass filter on an input vector
425410
and returns the result. Ends are padded to reduce transients.
@@ -462,7 +447,6 @@ def dolpfftfilt(Fs, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
462447
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
463448

464449

465-
@conditionaljit()
466450
def dohpfftfilt(Fs, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
467451
r"""Performs an FFT brickwall highpass filter on an input vector
468452
and returns the result. Ends are padded to reduce transients.
@@ -505,7 +489,6 @@ def dohpfftfilt(Fs, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
505489
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
506490

507491

508-
@conditionaljit()
509492
def dobpfftfilt(Fs, lowerpass, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
510493
r"""Performs an FFT brickwall bandpass filter on an input vector
511494
and returns the result. Ends are padded to reduce transients.
@@ -555,7 +538,8 @@ def dobpfftfilt(Fs, lowerpass, upperpass, inputdata, padlen=20, cyclic=False, de
555538

556539

557540
# - fft trapezoidal filters
558-
@conditionaljit()
541+
542+
559543
def getlptrapfftfunc(Fs, upperpass, upperstop, inputdata, debug=False):
560544
r"""Generates a trapezoidal lowpass transfer function.
561545
@@ -608,7 +592,6 @@ def getlptrapfftfunc(Fs, upperpass, upperstop, inputdata, debug=False):
608592
return transferfunc
609593

610594

611-
@conditionaljit()
612595
def getlptransfunc(Fs, inputdata, upperpass=None, upperstop=None, type="brickwall", debug=False):
613596
if upperpass is None:
614597
print("getlptransfunc: upperpass must be specified")
@@ -693,7 +676,6 @@ def gethptransfunc(Fs, inputdata, lowerstop=None, lowerpass=None, type="brickwal
693676
return transferfunc
694677

695678

696-
@conditionaljit()
697679
def dolptransfuncfilt(
698680
Fs,
699681
inputdata,
@@ -757,7 +739,6 @@ def dolptransfuncfilt(
757739
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
758740

759741

760-
@conditionaljit()
761742
def dohptransfuncfilt(
762743
Fs,
763744
inputdata,
@@ -827,7 +808,6 @@ def dohptransfuncfilt(
827808
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
828809

829810

830-
@conditionaljit()
831811
def dobptransfuncfilt(
832812
Fs,
833813
inputdata,
@@ -908,7 +888,6 @@ def dobptransfuncfilt(
908888
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
909889

910890

911-
@conditionaljit()
912891
def dolptrapfftfilt(Fs, upperpass, upperstop, inputdata, padlen=20, cyclic=False, debug=False):
913892
r"""Performs an FFT filter with a trapezoidal lowpass transfer
914893
function on an input vector and returns the result. Ends are padded to reduce transients.
@@ -955,7 +934,6 @@ def dolptrapfftfilt(Fs, upperpass, upperstop, inputdata, padlen=20, cyclic=False
955934
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
956935

957936

958-
@conditionaljit()
959937
def dohptrapfftfilt(Fs, lowerstop, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
960938
r"""Performs an FFT filter with a trapezoidal highpass transfer
961939
function on an input vector and returns the result. Ends are padded to reduce transients.
@@ -1002,7 +980,6 @@ def dohptrapfftfilt(Fs, lowerstop, lowerpass, inputdata, padlen=20, cyclic=False
1002980
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)
1003981

1004982

1005-
@conditionaljit()
1006983
def dobptrapfftfilt(
1007984
Fs,
1008985
lowerstop,
@@ -1293,7 +1270,6 @@ def csdfilter(obsdata, commondata, padlen=20, cyclic=False, debug=False):
12931270
return unpadvec(fftpack.ifft(obsdata_trans).real, padlen=padlen)
12941271

12951272

1296-
@conditionaljit()
12971273
def arb_pass(
12981274
Fs,
12991275
inputdata,

capcalc/fit.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,15 @@
2323

2424
import matplotlib.pyplot as plt
2525
import numpy as np
26-
import pyfftw
2726
import scipy as sp
2827
import scipy.special as sps
29-
30-
# from numba import jit
3128
from scipy.signal import find_peaks, hilbert
3229

3330
import capcalc.util as ccalc_util
3431

35-
fftpack = pyfftw.interfaces.scipy_fftpack
36-
pyfftw.interfaces.cache.enable()
37-
3832
# ---------------------------------------- Global constants -------------------------------------------
3933
defaultbutterorder = 6
4034
MAXLINES = 10000000
41-
donotbeaggressive = True
4235

4336
# ----------------------------------------- Conditional imports ---------------------------------------
4437
try:
@@ -48,31 +41,6 @@
4841
except ImportError:
4942
memprofilerexists = False
5043

51-
donotusenumba = True
52-
53-
54-
def conditionaljit():
55-
def resdec(f):
56-
if donotusenumba:
57-
return f
58-
return jit(f, nopython=False)
59-
60-
return resdec
61-
62-
63-
def conditionaljit2():
64-
def resdec(f):
65-
if donotusenumba or donotbeaggressive:
66-
return f
67-
return jit(f, nopython=False)
68-
69-
return resdec
70-
71-
72-
def disablenumba():
73-
global donotusenumba
74-
donotusenumba = True
75-
7644

7745
# --------------------------- Fitting functions -------------------------------------------------
7846
def gaussresidualssk(p, y, x):
@@ -108,7 +76,6 @@ def gaussskresiduals(p, y, x):
10876
return y - gausssk_eval(x, p)
10977

11078

111-
@conditionaljit()
11279
def gaussresiduals(p, y, x):
11380
"""
11481
@@ -174,7 +141,6 @@ def gausssk_eval(x, p):
174141
return p[0] * sp.stats.norm.pdf(t) * sp.stats.norm.cdf(p[3] * t)
175142

176143

177-
@conditionaljit()
178144
def kaiserbessel_eval(x, p):
179145
"""
180146
@@ -201,7 +167,6 @@ def kaiserbessel_eval(x, p):
201167
)
202168

203169

204-
@conditionaljit()
205170
def gauss_eval(x, p):
206171
"""
207172
@@ -254,7 +219,6 @@ def risetime_eval_loop(x, p):
254219
return r
255220

256221

257-
@conditionaljit()
258222
def trapezoid_eval(x, toplength, p):
259223
"""
260224
@@ -277,7 +241,6 @@ def trapezoid_eval(x, toplength, p):
277241
return p[1] * (np.exp(-(corrx - toplength) / p[3]))
278242

279243

280-
@conditionaljit()
281244
def risetime_eval(x, p):
282245
"""
283246
@@ -367,7 +330,8 @@ def locpeak(data, samplerate, lastpeaktime, winsizeinsecs=5.0, thresh=0.75, hyst
367330

368331

369332
# generate the polynomial fit timecourse from the coefficients
370-
@conditionaljit()
333+
334+
371335
def trendgen(thexvals, thefitcoffs, demean):
372336
"""
373337
@@ -416,7 +380,6 @@ def detrend(inputdata, order=1, demean=False):
416380
return inputdata - thefittc
417381

418382

419-
@conditionaljit()
420383
def findfirstabove(theyvals, thevalue):
421384
"""
422385
@@ -640,7 +603,6 @@ def territorydecomp(
640603
return fitmap, thecoffs, theRs
641604

642605

643-
@conditionaljit()
644606
def refinepeak_quad(x, y, peakindex, stride=1):
645607
# first make sure this actually is a peak
646608
ismax = None
@@ -670,7 +632,6 @@ def refinepeak_quad(x, y, peakindex, stride=1):
670632
return peakloc, peakval, peakwidth, ismax, badfit
671633

672634

673-
@conditionaljit2()
674635
def findmaxlag_gauss(
675636
thexcorr_x,
676637
thexcorr_y,
@@ -924,7 +885,6 @@ def findmaxlag_gauss(
924885
return maxindex, maxlag, maxval, maxsigma, maskval, failreason, fitstart, fitend
925886

926887

927-
@conditionaljit2()
928888
def maxindex_noedge(thexcorr_x, thexcorr_y, bipolar=False):
929889
"""
930890
@@ -965,8 +925,6 @@ def maxindex_noedge(thexcorr_x, thexcorr_y, bipolar=False):
965925
return maxindex, flipfac
966926

967927

968-
# disabled conditionaljit on 11/8/16. This causes crashes on some machines (but not mine, strangely enough)
969-
@conditionaljit2()
970928
def findmaxlag_gauss_rev(
971929
thexcorr_x,
972930
thexcorr_y,
@@ -1266,7 +1224,6 @@ def findmaxlag_gauss_rev(
12661224
)
12671225

12681226

1269-
@conditionaljit2()
12701227
def findmaxlag_quad(
12711228
thexcorr_x,
12721229
thexcorr_y,

0 commit comments

Comments
 (0)