@@ -40,18 +40,16 @@ def tanh(expr):
4040 tanh_expr ))
4141
4242
43- def sqrt (expr , to_reuse = False ):
43+ def sqrt (expr ):
4444 return ast .PowExpr (
4545 base_expr = expr ,
46- exp_expr = ast .NumVal (0.5 ),
47- to_reuse = to_reuse )
46+ exp_expr = ast .NumVal (0.5 ))
4847
4948
50- def exp (expr , to_reuse = False ):
49+ def exp (expr ):
5150 return ast .PowExpr (
5251 base_expr = ast .NumVal (math .e ),
53- exp_expr = expr ,
54- to_reuse = to_reuse )
52+ exp_expr = expr )
5553
5654
5755def log1p (expr ):
@@ -66,6 +64,118 @@ def log1p(expr):
6664 utils .div (utils .mul (expr , ast .LogExpr (expr1p )), expr1pm1 ))
6765
6866
67+ def atan (expr ):
68+ expr = ast .IdExpr (expr , to_reuse = True )
69+ expr_abs = ast .AbsExpr (expr , to_reuse = True )
70+
71+ expr_reduced = ast .IdExpr (
72+ ast .IfExpr (
73+ utils .gt (expr_abs , ast .NumVal (2.4142135623730950488 )),
74+ utils .div (ast .NumVal (1.0 ), expr_abs ),
75+ ast .IfExpr (
76+ utils .gt (expr_abs , ast .NumVal (0.66 )),
77+ utils .div (
78+ utils .sub (expr_abs , ast .NumVal (1.0 )),
79+ utils .add (expr_abs , ast .NumVal (1.0 ))),
80+ expr_abs )),
81+ to_reuse = True )
82+
83+ P0 = ast .NumVal (- 8.750608600031904122785e-01 )
84+ P1 = ast .NumVal (1.615753718733365076637e+01 )
85+ P2 = ast .NumVal (7.500855792314704667340e+01 )
86+ P3 = ast .NumVal (1.228866684490136173410e+02 )
87+ P4 = ast .NumVal (6.485021904942025371773e+01 )
88+ Q0 = ast .NumVal (2.485846490142306297962e+01 )
89+ Q1 = ast .NumVal (1.650270098316988542046e+02 )
90+ Q2 = ast .NumVal (4.328810604912902668951e+02 )
91+ Q3 = ast .NumVal (4.853903996359136964868e+02 )
92+ Q4 = ast .NumVal (1.945506571482613964425e+02 )
93+ expr2 = utils .mul (expr_reduced , expr_reduced , to_reuse = True )
94+ z = utils .mul (
95+ expr2 ,
96+ utils .div (
97+ utils .sub (
98+ utils .mul (
99+ expr2 ,
100+ utils .sub (
101+ utils .mul (
102+ expr2 ,
103+ utils .sub (
104+ utils .mul (
105+ expr2 ,
106+ utils .sub (
107+ utils .mul (
108+ expr2 ,
109+ P0
110+ ),
111+ P1
112+ )
113+ ),
114+ P2
115+ )
116+ ),
117+ P3
118+ )
119+ ),
120+ P4
121+ ),
122+ utils .add (
123+ Q4 ,
124+ utils .mul (
125+ expr2 ,
126+ utils .add (
127+ Q3 ,
128+ utils .mul (
129+ expr2 ,
130+ utils .add (
131+ Q2 ,
132+ utils .mul (
133+ expr2 ,
134+ utils .add (
135+ Q1 ,
136+ utils .mul (
137+ expr2 ,
138+ utils .add (
139+ Q0 ,
140+ expr2
141+ )
142+ )
143+ )
144+ )
145+ )
146+ )
147+ )
148+ )
149+ )
150+ )
151+ )
152+ z = utils .add (utils .mul (expr_reduced , z ), expr_reduced )
153+
154+ ret = utils .mul (
155+ z ,
156+ ast .IfExpr (
157+ utils .gt (expr_abs , ast .NumVal (2.4142135623730950488 )),
158+ ast .NumVal (- 1.0 ),
159+ ast .NumVal (1.0 )))
160+ ret = utils .add (
161+ ret ,
162+ ast .IfExpr (
163+ utils .lte (expr_abs , ast .NumVal (0.66 )),
164+ ast .NumVal (0.0 ),
165+ ast .IfExpr (
166+ utils .gt (expr_abs , ast .NumVal (2.4142135623730950488 )),
167+ ast .NumVal (1.570796326794896680463661649 ),
168+ ast .NumVal (0.7853981633974483402318308245 ))))
169+ ret = utils .mul (
170+ ret ,
171+ ast .IfExpr (
172+ utils .lt (expr , ast .NumVal (0.0 )),
173+ ast .NumVal (- 1.0 ),
174+ ast .NumVal (1.0 )))
175+
176+ return ret
177+
178+
69179def sigmoid (expr , to_reuse = False ):
70180 neg_expr = ast .BinNumExpr (ast .NumVal (0.0 ), expr , ast .BinNumOpType .SUB )
71181 exp_expr = ast .ExpExpr (neg_expr )
0 commit comments