|
| 1 | +import sys |
| 2 | + |
| 3 | +import pytensor.scalar as ps |
| 4 | +from pytensor.scalar import ScalarOp |
| 5 | +from pytensor.xtensor.vectorization import XElemwise |
| 6 | + |
| 7 | + |
| 8 | +this_module = sys.modules[__name__] |
| 9 | + |
| 10 | + |
| 11 | +def _as_xelemwise(core_op: ScalarOp) -> XElemwise: |
| 12 | + out = XElemwise(core_op) |
| 13 | + out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables" |
| 14 | + return out |
| 15 | + |
| 16 | + |
| 17 | +abs = _as_xelemwise(ps.abs) |
| 18 | +add = _as_xelemwise(ps.add) |
| 19 | +logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_) |
| 20 | +angle = _as_xelemwise(ps.angle) |
| 21 | +arccos = _as_xelemwise(ps.arccos) |
| 22 | +arccosh = _as_xelemwise(ps.arccosh) |
| 23 | +arcsin = _as_xelemwise(ps.arcsin) |
| 24 | +arcsinh = _as_xelemwise(ps.arcsinh) |
| 25 | +arctan = _as_xelemwise(ps.arctan) |
| 26 | +arctan2 = _as_xelemwise(ps.arctan2) |
| 27 | +arctanh = _as_xelemwise(ps.arctanh) |
| 28 | +betainc = _as_xelemwise(ps.betainc) |
| 29 | +betaincinv = _as_xelemwise(ps.betaincinv) |
| 30 | +ceil = _as_xelemwise(ps.ceil) |
| 31 | +clip = _as_xelemwise(ps.clip) |
| 32 | +complex = _as_xelemwise(ps.complex) |
| 33 | +conjugate = conj = _as_xelemwise(ps.conj) |
| 34 | +cos = _as_xelemwise(ps.cos) |
| 35 | +cosh = _as_xelemwise(ps.cosh) |
| 36 | +deg2rad = _as_xelemwise(ps.deg2rad) |
| 37 | +equal = eq = _as_xelemwise(ps.eq) |
| 38 | +erf = _as_xelemwise(ps.erf) |
| 39 | +erfc = _as_xelemwise(ps.erfc) |
| 40 | +erfcinv = _as_xelemwise(ps.erfcinv) |
| 41 | +erfcx = _as_xelemwise(ps.erfcx) |
| 42 | +erfinv = _as_xelemwise(ps.erfinv) |
| 43 | +exp = _as_xelemwise(ps.exp) |
| 44 | +exp2 = _as_xelemwise(ps.exp2) |
| 45 | +expm1 = _as_xelemwise(ps.expm1) |
| 46 | +floor = _as_xelemwise(ps.floor) |
| 47 | +floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div) |
| 48 | +gamma = _as_xelemwise(ps.gamma) |
| 49 | +gammainc = _as_xelemwise(ps.gammainc) |
| 50 | +gammaincc = _as_xelemwise(ps.gammaincc) |
| 51 | +gammainccinv = _as_xelemwise(ps.gammainccinv) |
| 52 | +gammaincinv = _as_xelemwise(ps.gammaincinv) |
| 53 | +gammal = _as_xelemwise(ps.gammal) |
| 54 | +gammaln = _as_xelemwise(ps.gammaln) |
| 55 | +gammau = _as_xelemwise(ps.gammau) |
| 56 | +greater_equal = ge = _as_xelemwise(ps.ge) |
| 57 | +greater = gt = _as_xelemwise(ps.gt) |
| 58 | +hyp2f1 = _as_xelemwise(ps.hyp2f1) |
| 59 | +i0 = _as_xelemwise(ps.i0) |
| 60 | +i1 = _as_xelemwise(ps.i1) |
| 61 | +identity = _as_xelemwise(ps.identity) |
| 62 | +imag = _as_xelemwise(ps.imag) |
| 63 | +logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert) |
| 64 | +isinf = _as_xelemwise(ps.isinf) |
| 65 | +isnan = _as_xelemwise(ps.isnan) |
| 66 | +iv = _as_xelemwise(ps.iv) |
| 67 | +ive = _as_xelemwise(ps.ive) |
| 68 | +j0 = _as_xelemwise(ps.j0) |
| 69 | +j1 = _as_xelemwise(ps.j1) |
| 70 | +jv = _as_xelemwise(ps.jv) |
| 71 | +kve = _as_xelemwise(ps.kve) |
| 72 | +less_equal = le = _as_xelemwise(ps.le) |
| 73 | +log = _as_xelemwise(ps.log) |
| 74 | +log10 = _as_xelemwise(ps.log10) |
| 75 | +log1mexp = _as_xelemwise(ps.log1mexp) |
| 76 | +log1p = _as_xelemwise(ps.log1p) |
| 77 | +log2 = _as_xelemwise(ps.log2) |
| 78 | +less = lt = _as_xelemwise(ps.lt) |
| 79 | +mod = _as_xelemwise(ps.mod) |
| 80 | +multiply = mul = _as_xelemwise(ps.mul) |
| 81 | +negative = neg = _as_xelemwise(ps.neg) |
| 82 | +not_equal = neq = _as_xelemwise(ps.neq) |
| 83 | +logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_) |
| 84 | +owens_t = _as_xelemwise(ps.owens_t) |
| 85 | +polygamma = _as_xelemwise(ps.polygamma) |
| 86 | +power = pow = _as_xelemwise(ps.pow) |
| 87 | +psi = _as_xelemwise(ps.psi) |
| 88 | +rad2deg = _as_xelemwise(ps.rad2deg) |
| 89 | +real = _as_xelemwise(ps.real) |
| 90 | +reciprocal = _as_xelemwise(ps.reciprocal) |
| 91 | +round = _as_xelemwise(ps.round_half_to_even) |
| 92 | +maximum = _as_xelemwise(ps.scalar_maximum) |
| 93 | +minimum = _as_xelemwise(ps.scalar_minimum) |
| 94 | +second = _as_xelemwise(ps.second) |
| 95 | +sigmoid = _as_xelemwise(ps.sigmoid) |
| 96 | +sign = _as_xelemwise(ps.sign) |
| 97 | +sin = _as_xelemwise(ps.sin) |
| 98 | +sinh = _as_xelemwise(ps.sinh) |
| 99 | +softplus = _as_xelemwise(ps.softplus) |
| 100 | +square = sqr = _as_xelemwise(ps.sqr) |
| 101 | +sqrt = _as_xelemwise(ps.sqrt) |
| 102 | +subtract = sub = _as_xelemwise(ps.sub) |
| 103 | +where = switch = _as_xelemwise(ps.switch) |
| 104 | +tan = _as_xelemwise(ps.tan) |
| 105 | +tanh = _as_xelemwise(ps.tanh) |
| 106 | +tri_gamma = _as_xelemwise(ps.tri_gamma) |
| 107 | +true_divide = true_div = _as_xelemwise(ps.true_div) |
| 108 | +trunc = _as_xelemwise(ps.trunc) |
| 109 | +logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor) |
0 commit comments