Skip to content

Commit 91ef2ef

Browse files
authored
Add cauchy link function support to the linear model assembler (#304)
* added atan expression * added atan for F# * fix lint * added cauchy link function * hotfix
1 parent 0c1ea76 commit 91ef2ef

38 files changed

+560
-12
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ recursive-include m2cgen VERSION.txt
33
recursive-include m2cgen linear_algebra.*
44
recursive-include m2cgen log1p.*
55
recursive-include m2cgen tanh.*
6+
recursive-include m2cgen atan.*
67
global-exclude *.py[cod]

m2cgen/assemblers/fallback_expressions.py

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,16 @@ def tanh(expr):
4040
tanh_expr))
4141

4242

43-
def sqrt(expr, to_reuse=False):
43+
def sqrt(expr):
4444
return ast.PowExpr(
4545
base_expr=expr,
46-
exp_expr=ast.NumVal(0.5),
47-
to_reuse=to_reuse)
46+
exp_expr=ast.NumVal(0.5))
4847

4948

50-
def exp(expr, to_reuse=False):
49+
def exp(expr):
5150
return ast.PowExpr(
5251
base_expr=ast.NumVal(math.e),
53-
exp_expr=expr,
54-
to_reuse=to_reuse)
52+
exp_expr=expr)
5553

5654

5755
def log1p(expr):
@@ -66,6 +64,118 @@ def log1p(expr):
6664
utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1))
6765

6866

67+
def atan(expr):
68+
expr = ast.IdExpr(expr, to_reuse=True)
69+
expr_abs = ast.AbsExpr(expr, to_reuse=True)
70+
71+
expr_reduced = ast.IdExpr(
72+
ast.IfExpr(
73+
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
74+
utils.div(ast.NumVal(1.0), expr_abs),
75+
ast.IfExpr(
76+
utils.gt(expr_abs, ast.NumVal(0.66)),
77+
utils.div(
78+
utils.sub(expr_abs, ast.NumVal(1.0)),
79+
utils.add(expr_abs, ast.NumVal(1.0))),
80+
expr_abs)),
81+
to_reuse=True)
82+
83+
P0 = ast.NumVal(-8.750608600031904122785e-01)
84+
P1 = ast.NumVal(1.615753718733365076637e+01)
85+
P2 = ast.NumVal(7.500855792314704667340e+01)
86+
P3 = ast.NumVal(1.228866684490136173410e+02)
87+
P4 = ast.NumVal(6.485021904942025371773e+01)
88+
Q0 = ast.NumVal(2.485846490142306297962e+01)
89+
Q1 = ast.NumVal(1.650270098316988542046e+02)
90+
Q2 = ast.NumVal(4.328810604912902668951e+02)
91+
Q3 = ast.NumVal(4.853903996359136964868e+02)
92+
Q4 = ast.NumVal(1.945506571482613964425e+02)
93+
expr2 = utils.mul(expr_reduced, expr_reduced, to_reuse=True)
94+
z = utils.mul(
95+
expr2,
96+
utils.div(
97+
utils.sub(
98+
utils.mul(
99+
expr2,
100+
utils.sub(
101+
utils.mul(
102+
expr2,
103+
utils.sub(
104+
utils.mul(
105+
expr2,
106+
utils.sub(
107+
utils.mul(
108+
expr2,
109+
P0
110+
),
111+
P1
112+
)
113+
),
114+
P2
115+
)
116+
),
117+
P3
118+
)
119+
),
120+
P4
121+
),
122+
utils.add(
123+
Q4,
124+
utils.mul(
125+
expr2,
126+
utils.add(
127+
Q3,
128+
utils.mul(
129+
expr2,
130+
utils.add(
131+
Q2,
132+
utils.mul(
133+
expr2,
134+
utils.add(
135+
Q1,
136+
utils.mul(
137+
expr2,
138+
utils.add(
139+
Q0,
140+
expr2
141+
)
142+
)
143+
)
144+
)
145+
)
146+
)
147+
)
148+
)
149+
)
150+
)
151+
)
152+
z = utils.add(utils.mul(expr_reduced, z), expr_reduced)
153+
154+
ret = utils.mul(
155+
z,
156+
ast.IfExpr(
157+
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
158+
ast.NumVal(-1.0),
159+
ast.NumVal(1.0)))
160+
ret = utils.add(
161+
ret,
162+
ast.IfExpr(
163+
utils.lte(expr_abs, ast.NumVal(0.66)),
164+
ast.NumVal(0.0),
165+
ast.IfExpr(
166+
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
167+
ast.NumVal(1.570796326794896680463661649),
168+
ast.NumVal(0.7853981633974483402318308245))))
169+
ret = utils.mul(
170+
ret,
171+
ast.IfExpr(
172+
utils.lt(expr, ast.NumVal(0.0)),
173+
ast.NumVal(-1.0),
174+
ast.NumVal(1.0)))
175+
176+
return ret
177+
178+
69179
def sigmoid(expr, to_reuse=False):
70180
neg_expr = ast.BinNumExpr(ast.NumVal(0.0), expr, ast.BinNumOpType.SUB)
71181
exp_expr = ast.ExpExpr(neg_expr)

m2cgen/assemblers/linear.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import numpy as np
24

35
from m2cgen import ast
@@ -149,6 +151,13 @@ def _negativebinomial_inversed(self, ast_to_transform):
149151
ast.NumVal(-1.0),
150152
utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res)
151153

154+
def _cauchy_inversed(self, ast_to_transform):
155+
return utils.add(
156+
ast.NumVal(0.5),
157+
utils.div(
158+
ast.AtanExpr(ast_to_transform),
159+
ast.NumVal(math.pi)))
160+
152161
def _get_power(self):
153162
raise NotImplementedError
154163

@@ -172,7 +181,8 @@ def _get_supported_inversed_funs(self):
172181
"log": self._log_inversed,
173182
"cloglog": self._cloglog_inversed,
174183
"negativebinomial": self._negativebinomial_inversed,
175-
"nbinom": self._negativebinomial_inversed
184+
"nbinom": self._negativebinomial_inversed,
185+
"cauchy": self._cauchy_inversed
176186
}
177187

178188
def _get_power(self):

m2cgen/ast.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,23 @@ def __hash__(self):
8686
return hash(self.expr)
8787

8888

89+
class AtanExpr(NumExpr):
90+
def __init__(self, expr, to_reuse=False):
91+
assert expr.output_size == 1, "Only scalars are supported"
92+
93+
self.expr = expr
94+
self.to_reuse = to_reuse
95+
96+
def __str__(self):
97+
return f"AtanExpr({self.expr},to_reuse={self.to_reuse})"
98+
99+
def __eq__(self, other):
100+
return type(other) is AtanExpr and self.expr == other.expr
101+
102+
def __hash__(self):
103+
return hash(self.expr)
104+
105+
89106
class ExpExpr(NumExpr):
90107
def __init__(self, expr, to_reuse=False):
91108
assert expr.output_size == 1, "Only scalars are supported"
@@ -370,7 +387,8 @@ def __hash__(self):
370387
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
371388
(VectorVal, lambda e: e.exprs),
372389
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
373-
((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr),
390+
((AbsExpr, AtanExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr,
391+
SqrtExpr, TanhExpr),
374392
lambda e: [e.expr]),
375393
]
376394

m2cgen/interpreters/c/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class CInterpreter(ImperativeToCodeInterpreter,
1818
}
1919

2020
abs_function_name = "fabs"
21+
atan_function_name = "atan"
2122
exponent_function_name = "exp"
2223
logarithm_function_name = "log"
2324
log1p_function_name = "log1p"

m2cgen/interpreters/c_sharp/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,
1919
}
2020

2121
abs_function_name = "Abs"
22+
atan_function_name = "Atan"
2223
exponent_function_name = "Exp"
2324
logarithm_function_name = "Log"
2425
log1p_function_name = "Log1p"

m2cgen/interpreters/dart/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class DartInterpreter(ImperativeToCodeInterpreter,
2222
bin_depth_threshold = 465
2323

2424
abs_function_name = "abs"
25+
atan_function_name = "atan"
2526
exponent_function_name = "exp"
2627
logarithm_function_name = "log"
2728
log1p_function_name = "log1p"

m2cgen/interpreters/f_sharp/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class FSharpInterpreter(FunctionalToCodeInterpreter,
2626
}
2727

2828
abs_function_name = "abs"
29+
atan_function_name = "atan"
2930
exponent_function_name = "exp"
3031
logarithm_function_name = "log"
3132
log1p_function_name = "log1p"

m2cgen/interpreters/go/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class GoInterpreter(ImperativeToCodeInterpreter,
1717
}
1818

1919
abs_function_name = "math.Abs"
20+
atan_function_name = "math.Atan"
2021
exponent_function_name = "math.Exp"
2122
logarithm_function_name = "math.Log"
2223
log1p_function_name = "math.Log1p"

m2cgen/interpreters/haskell/interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class HaskellInterpreter(FunctionalToCodeInterpreter,
1717
}
1818

1919
abs_function_name = "abs"
20+
atan_function_name = "atan"
2021
exponent_function_name = "exp"
2122
logarithm_function_name = "log"
2223
log1p_function_name = "log1p"

0 commit comments

Comments
 (0)