Skip to content

Commit c819663

Browse files
committed
Lazy scipy special, stats and linalg imports
1 parent 7fa8d58 commit c819663

File tree

4 files changed

+179
-74
lines changed

4 files changed

+179
-74
lines changed

pytensor/scalar/math.py

Lines changed: 90 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from textwrap import dedent
1010

1111
import numpy as np
12-
import scipy.special
13-
import scipy.stats
1412

1513
from pytensor.configdefaults import config
1614
from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -54,7 +52,9 @@ class Erf(UnaryScalarOp):
5452
nfunc_spec = ("scipy.special.erf", 1, 1)
5553

5654
def impl(self, x):
57-
return scipy.special.erf(x)
55+
from scipy.special import erf
56+
57+
return erf(x)
5858

5959
def L_op(self, inputs, outputs, grads):
6060
(x,) = inputs
@@ -88,7 +88,9 @@ class Erfc(UnaryScalarOp):
8888
nfunc_spec = ("scipy.special.erfc", 1, 1)
8989

9090
def impl(self, x):
91-
return scipy.special.erfc(x)
91+
from scipy.special import erfc
92+
93+
return erfc(x)
9294

9395
def L_op(self, inputs, outputs, grads):
9496
(x,) = inputs
@@ -137,7 +139,9 @@ class Erfcx(UnaryScalarOp):
137139
nfunc_spec = ("scipy.special.erfcx", 1, 1)
138140

139141
def impl(self, x):
140-
return scipy.special.erfcx(x)
142+
from scipy.special import erfcx
143+
144+
return erfcx(x)
141145

142146
def L_op(self, inputs, outputs, grads):
143147
(x,) = inputs
@@ -193,7 +197,9 @@ class Erfinv(UnaryScalarOp):
193197
nfunc_spec = ("scipy.special.erfinv", 1, 1)
194198

195199
def impl(self, x):
196-
return scipy.special.erfinv(x)
200+
from scipy.special import erfinv
201+
202+
return erfinv(x)
197203

198204
def L_op(self, inputs, outputs, grads):
199205
(x,) = inputs
@@ -228,7 +234,9 @@ class Erfcinv(UnaryScalarOp):
228234
nfunc_spec = ("scipy.special.erfcinv", 1, 1)
229235

230236
def impl(self, x):
231-
return scipy.special.erfcinv(x)
237+
from scipy.special import erfcinv
238+
239+
return erfcinv(x)
232240

233241
def L_op(self, inputs, outputs, grads):
234242
(x,) = inputs
@@ -264,7 +272,9 @@ class Owens_t(BinaryScalarOp):
264272

265273
@staticmethod
266274
def st_impl(h, a):
267-
return scipy.special.owens_t(h, a)
275+
from scipy.special import owens_t
276+
277+
return owens_t(h, a)
268278

269279
def impl(self, h, a):
270280
return Owens_t.st_impl(h, a)
@@ -293,7 +303,9 @@ class Gamma(UnaryScalarOp):
293303

294304
@staticmethod
295305
def st_impl(x):
296-
return scipy.special.gamma(x)
306+
from scipy.special import gamma
307+
308+
return gamma(x)
297309

298310
def impl(self, x):
299311
return Gamma.st_impl(x)
@@ -332,7 +344,9 @@ class GammaLn(UnaryScalarOp):
332344

333345
@staticmethod
334346
def st_impl(x):
335-
return scipy.special.gammaln(x)
347+
from scipy.special import gammaln
348+
349+
return gammaln(x)
336350

337351
def impl(self, x):
338352
return GammaLn.st_impl(x)
@@ -376,7 +390,9 @@ class Psi(UnaryScalarOp):
376390

377391
@staticmethod
378392
def st_impl(x):
379-
return scipy.special.psi(x)
393+
from scipy.special import psi
394+
395+
return psi(x)
380396

381397
def impl(self, x):
382398
return Psi.st_impl(x)
@@ -467,7 +483,9 @@ class TriGamma(UnaryScalarOp):
467483

468484
@staticmethod
469485
def st_impl(x):
470-
return scipy.special.polygamma(1, x)
486+
from scipy.special import polygamma
487+
488+
return polygamma(1, x)
471489

472490
def impl(self, x):
473491
return TriGamma.st_impl(x)
@@ -570,7 +588,9 @@ def output_types_preference(n_type, x_type):
570588

571589
@staticmethod
572590
def st_impl(n, x):
573-
return scipy.special.polygamma(n, x)
591+
from scipy.special import polygamma
592+
593+
return polygamma(n, x)
574594

575595
def impl(self, n, x):
576596
return PolyGamma.st_impl(n, x)
@@ -602,7 +622,9 @@ class Chi2SF(BinaryScalarOp):
602622

603623
@staticmethod
604624
def st_impl(x, k):
605-
return scipy.stats.chi2.sf(x, k)
625+
from scipy.stats import chi2
626+
627+
return chi2.sf(x, k)
606628

607629
def impl(self, x, k):
608630
return Chi2SF.st_impl(x, k)
@@ -645,7 +667,9 @@ class GammaInc(BinaryScalarOp):
645667

646668
@staticmethod
647669
def st_impl(k, x):
648-
return scipy.special.gammainc(k, x)
670+
from scipy.special import gammainc
671+
672+
return gammainc(k, x)
649673

650674
def impl(self, k, x):
651675
return GammaInc.st_impl(k, x)
@@ -696,7 +720,9 @@ class GammaIncC(BinaryScalarOp):
696720

697721
@staticmethod
698722
def st_impl(k, x):
699-
return scipy.special.gammaincc(k, x)
723+
from scipy.special import gammaincc
724+
725+
return gammaincc(k, x)
700726

701727
def impl(self, k, x):
702728
return GammaIncC.st_impl(k, x)
@@ -747,7 +773,9 @@ class GammaIncInv(BinaryScalarOp):
747773

748774
@staticmethod
749775
def st_impl(k, x):
750-
return scipy.special.gammaincinv(k, x)
776+
from scipy.special import gammaincinv
777+
778+
return gammaincinv(k, x)
751779

752780
def impl(self, k, x):
753781
return GammaIncInv.st_impl(k, x)
@@ -776,7 +804,9 @@ class GammaIncCInv(BinaryScalarOp):
776804

777805
@staticmethod
778806
def st_impl(k, x):
779-
return scipy.special.gammainccinv(k, x)
807+
from scipy.special import gammainccinv
808+
809+
return gammainccinv(k, x)
780810

781811
def impl(self, k, x):
782812
return GammaIncCInv.st_impl(k, x)
@@ -1015,7 +1045,9 @@ class GammaU(BinaryScalarOp):
10151045

10161046
@staticmethod
10171047
def st_impl(k, x):
1018-
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
1048+
from scipy.special import gamma, gammaincc
1049+
1050+
return gammaincc(k, x) * gamma(k)
10191051

10201052
def impl(self, k, x):
10211053
return GammaU.st_impl(k, x)
@@ -1051,7 +1083,9 @@ class GammaL(BinaryScalarOp):
10511083

10521084
@staticmethod
10531085
def st_impl(k, x):
1054-
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
1086+
from scipy.special import gamma, gammainc
1087+
1088+
return gammainc(k, x) * gamma(k)
10551089

10561090
def impl(self, k, x):
10571091
return GammaL.st_impl(k, x)
@@ -1087,7 +1121,9 @@ class Jv(BinaryScalarOp):
10871121

10881122
@staticmethod
10891123
def st_impl(v, x):
1090-
return scipy.special.jv(v, x)
1124+
from scipy.special import jv
1125+
1126+
return jv(v, x)
10911127

10921128
def impl(self, v, x):
10931129
return self.st_impl(v, x)
@@ -1116,7 +1152,9 @@ class J1(UnaryScalarOp):
11161152

11171153
@staticmethod
11181154
def st_impl(x):
1119-
return scipy.special.j1(x)
1155+
from scipy.special import j1
1156+
1157+
return j1(x)
11201158

11211159
def impl(self, x):
11221160
return self.st_impl(x)
@@ -1147,7 +1185,9 @@ class J0(UnaryScalarOp):
11471185

11481186
@staticmethod
11491187
def st_impl(x):
1150-
return scipy.special.j0(x)
1188+
from scipy.special import j0
1189+
1190+
return j0(x)
11511191

11521192
def impl(self, x):
11531193
return self.st_impl(x)
@@ -1178,7 +1218,9 @@ class Iv(BinaryScalarOp):
11781218

11791219
@staticmethod
11801220
def st_impl(v, x):
1181-
return scipy.special.iv(v, x)
1221+
from scipy.special import iv
1222+
1223+
return iv(v, x)
11821224

11831225
def impl(self, v, x):
11841226
return self.st_impl(v, x)
@@ -1207,7 +1249,9 @@ class I1(UnaryScalarOp):
12071249

12081250
@staticmethod
12091251
def st_impl(x):
1210-
return scipy.special.i1(x)
1252+
from scipy.special import i1
1253+
1254+
return i1(x)
12111255

12121256
def impl(self, x):
12131257
return self.st_impl(x)
@@ -1233,7 +1277,9 @@ class I0(UnaryScalarOp):
12331277

12341278
@staticmethod
12351279
def st_impl(x):
1236-
return scipy.special.i0(x)
1280+
from scipy.special import i0
1281+
1282+
return i0(x)
12371283

12381284
def impl(self, x):
12391285
return self.st_impl(x)
@@ -1259,7 +1305,9 @@ class Ive(BinaryScalarOp):
12591305

12601306
@staticmethod
12611307
def st_impl(v, x):
1262-
return scipy.special.ive(v, x)
1308+
from scipy.special import ive
1309+
1310+
return ive(v, x)
12631311

12641312
def impl(self, v, x):
12651313
return self.st_impl(v, x)
@@ -1288,7 +1336,9 @@ class Kve(BinaryScalarOp):
12881336

12891337
@staticmethod
12901338
def st_impl(v, x):
1291-
return scipy.special.kve(v, x)
1339+
from scipy.special import kve
1340+
1341+
return kve(v, x)
12921342

12931343
def impl(self, v, x):
12941344
return self.st_impl(v, x)
@@ -1321,7 +1371,9 @@ class Sigmoid(UnaryScalarOp):
13211371
nfunc_spec = ("scipy.special.expit", 1, 1)
13221372

13231373
def impl(self, x):
1324-
return scipy.special.expit(x)
1374+
from scipy.special import expit
1375+
1376+
return expit(x)
13251377

13261378
def grad(self, inp, grads):
13271379
(x,) = inp
@@ -1496,7 +1548,9 @@ class BetaInc(ScalarOp):
14961548
nfunc_spec = ("scipy.special.betainc", 3, 1)
14971549

14981550
def impl(self, a, b, x):
1499-
return scipy.special.betainc(a, b, x)
1551+
from scipy.special import betainc
1552+
1553+
return betainc(a, b, x)
15001554

15011555
def grad(self, inp, grads):
15021556
a, b, x = inp
@@ -1756,7 +1810,9 @@ class BetaIncInv(ScalarOp):
17561810
nfunc_spec = ("scipy.special.betaincinv", 3, 1)
17571811

17581812
def impl(self, a, b, x):
1759-
return scipy.special.betaincinv(a, b, x)
1813+
from scipy.special import betaincinv
1814+
1815+
return betaincinv(a, b, x)
17601816

17611817
def grad(self, inputs, grads):
17621818
(a, b, x) = inputs
@@ -1796,7 +1852,9 @@ class Hyp2F1(ScalarOp):
17961852

17971853
@staticmethod
17981854
def st_impl(a, b, c, z):
1799-
return scipy.special.hyp2f1(a, b, c, z)
1855+
from scipy.special import hyp2f1
1856+
1857+
return hyp2f1(a, b, c, z)
18001858

18011859
def impl(self, a, b, c, z):
18021860
return Hyp2F1.st_impl(a, b, c, z)

0 commit comments

Comments
 (0)