11import math
2+ import numpy as np
3+
4+
5+ def _get_math_module (x ):
6+ """Return the appropriate math module (math or numpy) based on input type."""
7+ if isinstance (x , np .ndarray ):
8+ return np
9+ else :
10+ return math
211
312
413class DNumber :
@@ -93,11 +102,12 @@ def __pow__(self, power):
93102 if isinstance (power , DNumber ):
94103 new_value = self .value ** power .value
95104 new_derivatives = {}
105+ m = _get_math_module (self .value )
96106 for var in set (self .derivatives .keys ()).union (power .derivatives .keys ()):
97107 deriv1 = self .derivatives .get (var , 0 )
98108 deriv2 = power .derivatives .get (var , 0 )
99109 new_derivatives [var ] = new_value * (
100- deriv1 * power .value / self .value + deriv2 * math .log (self .value )
110+ deriv1 * power .value / self .value + deriv2 * m .log (self .value )
101111 )
102112 return DNumber (new_value , new_derivatives )
103113 else :
@@ -113,8 +123,9 @@ def __rpow__(self, base):
113123 return base .__pow__ (self )
114124 else :
115125 new_value = base ** self .value
126+ m = _get_math_module (self .value )
116127 new_derivatives = {
117- var : deriv * new_value * math .log (base )
128+ var : deriv * new_value * m .log (base )
118129 for var , deriv in self .derivatives .items ()
119130 }
120131 return DNumber (new_value , new_derivatives )
@@ -133,72 +144,79 @@ def __repr__(self):
133144def sin (x ):
134145 """Sine function that works with both floats and DNumber objects."""
135146 if isinstance (x , DNumber ):
136- new_value = math . sin (x .value )
147+ new_value = sin (x .value )
137148 new_derivatives = {
138- var : deriv * math . cos (x .value ) for var , deriv in x .derivatives .items ()
149+ var : deriv * cos (x .value ) for var , deriv in x .derivatives .items ()
139150 }
140151 return DNumber (new_value , new_derivatives )
141152 else :
142- return math .sin (x )
153+ m = _get_math_module (x )
154+ return m .sin (x )
143155
144156
145157def cos (x ):
146158 """Cosine function that works with both floats and DNumber objects."""
147159 if isinstance (x , DNumber ):
148- new_value = math . cos (x .value )
160+ new_value = cos (x .value )
149161 new_derivatives = {
150- var : - deriv * math . sin (x .value ) for var , deriv in x .derivatives .items ()
162+ var : - deriv * sin (x .value ) for var , deriv in x .derivatives .items ()
151163 }
152164 return DNumber (new_value , new_derivatives )
153165 else :
154- return math .cos (x )
166+ m = _get_math_module (x )
167+ return m .cos (x )
155168
156169
157170def tan (x ):
158171 """Tangent function that works with both floats and DNumber objects."""
159172 if isinstance (x , DNumber ):
160- new_value = math . tan (x .value )
161- sec_squared = 1 / (math . cos (x .value ) ** 2 )
173+ new_value = tan (x .value )
174+ sec_squared = 1 / (cos (x .value ) ** 2 )
162175 new_derivatives = {
163176 var : deriv * sec_squared for var , deriv in x .derivatives .items ()
164177 }
165178 return DNumber (new_value , new_derivatives )
166179 else :
167- return math .tan (x )
180+ m = _get_math_module (x )
181+ return m .tan (x )
168182
169183
170184def exp (x ):
171185 """Exponential function that works with both floats and DNumber objects."""
186+ from rich import print
172187 if isinstance (x , DNumber ):
173- new_value = math . exp (x .value )
188+ new_value = exp (x .value )
174189 new_derivatives = {
175190 var : deriv * new_value for var , deriv in x .derivatives .items ()
176191 }
177192 return DNumber (new_value , new_derivatives )
178193 else :
179- return math .exp (x )
194+ m = _get_math_module (x )
195+ return m .exp (x )
180196
181197
182198def log (x ):
183199 """Natural logarithm function that works with both floats and DNumber objects."""
184200 if isinstance (x , DNumber ):
185- new_value = math . log (x .value )
201+ new_value = log (x .value )
186202 new_derivatives = {var : deriv / x .value for var , deriv in x .derivatives .items ()}
187203 return DNumber (new_value , new_derivatives )
188204 else :
189- return math .log (x )
205+ m = _get_math_module (x )
206+ return m .log (x )
190207
191208
192209def sqrt (x ):
193210 """Square root function that works with both floats and DNumber objects."""
194211 if isinstance (x , DNumber ):
195- new_value = math . sqrt (x .value )
212+ new_value = sqrt (x .value )
196213 new_derivatives = {
197214 var : deriv / (2 * new_value ) for var , deriv in x .derivatives .items ()
198215 }
199216 return DNumber (new_value , new_derivatives )
200217 else :
201- return math .sqrt (x )
218+ m = _get_math_module (x )
219+ return m .sqrt (x )
202220
203221
204222def dabs (x ):
@@ -215,77 +233,83 @@ def dabs(x):
215233def sinh (x ):
216234 """Hyperbolic sine function that works with both floats and DNumber objects."""
217235 if isinstance (x , DNumber ):
218- new_value = math . sinh (x .value )
236+ new_value = sinh (x .value )
219237 new_derivatives = {
220- var : deriv * math . cosh (x .value ) for var , deriv in x .derivatives .items ()
238+ var : deriv * cosh (x .value ) for var , deriv in x .derivatives .items ()
221239 }
222240 return DNumber (new_value , new_derivatives )
223241 else :
224- return math .sinh (x )
242+ m = _get_math_module (x )
243+ return m .sinh (x )
225244
226245
227246def cosh (x ):
228247 """Hyperbolic cosine function that works with both floats and DNumber objects."""
229248 if isinstance (x , DNumber ):
230- new_value = math . cosh (x .value )
249+ new_value = cosh (x .value )
231250 new_derivatives = {
232- var : deriv * math . sinh (x .value ) for var , deriv in x .derivatives .items ()
251+ var : deriv * sinh (x .value ) for var , deriv in x .derivatives .items ()
233252 }
234253 return DNumber (new_value , new_derivatives )
235254 else :
236- return math .cosh (x )
255+ m = _get_math_module (x )
256+ return m .cosh (x )
237257
238258
239259def tanh (x ):
240260 """Hyperbolic tangent function that works with both floats and DNumber objects."""
241261 if isinstance (x , DNumber ):
242- new_value = math . tanh (x .value )
262+ new_value = tanh (x .value )
243263 sech_squared = 1 - new_value ** 2
244264 new_derivatives = {
245265 var : deriv * sech_squared for var , deriv in x .derivatives .items ()
246266 }
247267 return DNumber (new_value , new_derivatives )
248268 else :
249- return math .tanh (x )
269+ m = _get_math_module (x )
270+ return m .tanh (x )
250271
251272
252273def asin (x ):
253274 """Arcsine function that works with both floats and DNumber objects."""
254275 if isinstance (x , DNumber ):
255- new_value = math . asin (x .value )
256- derivative_factor = 1 / math . sqrt (1 - x .value ** 2 )
276+ new_value = asin (x .value )
277+ derivative_factor = 1 / sqrt (1 - x .value ** 2 )
257278 new_derivatives = {
258279 var : deriv * derivative_factor for var , deriv in x .derivatives .items ()
259280 }
260281 return DNumber (new_value , new_derivatives )
261282 else :
262- return math .asin (x )
283+ m = _get_math_module (x )
284+ return np .arcsin (x ) if m is np else math .asin (x )
263285
264286
265287def acos (x ):
266288 """Arccosine function that works with both floats and DNumber objects."""
267289 if isinstance (x , DNumber ):
268- new_value = math . acos (x .value )
269- derivative_factor = - 1 / math . sqrt (1 - x .value ** 2 )
290+ new_value = acos (x .value )
291+ derivative_factor = - 1 / sqrt (1 - x .value ** 2 )
270292 new_derivatives = {
271293 var : deriv * derivative_factor for var , deriv in x .derivatives .items ()
272294 }
273295 return DNumber (new_value , new_derivatives )
274296 else :
275- return math .acos (x )
297+ m = _get_math_module (x )
298+ return np .arccos (x ) if m is np else math .acos (x )
276299
277300
278301def atan (x ):
279302 """Arctangent function that works with both floats and DNumber objects."""
280303 if isinstance (x , DNumber ):
281- new_value = math . atan (x .value )
304+ new_value = atan (x .value )
282305 derivative_factor = 1 / (1 + x .value ** 2 )
283306 new_derivatives = {
284307 var : deriv * derivative_factor for var , deriv in x .derivatives .items ()
285308 }
286309 return DNumber (new_value , new_derivatives )
287310 else :
288- return math .atan (x )
311+ m = _get_math_module (x )
312+ return np .arctan (x ) if m is np else math .atan (x )
289313
290314
291315def dmax (x , y ):
@@ -325,48 +349,52 @@ def dmin(x, y):
325349def log10 (x ):
326350 """Base-10 logarithm function that works with both floats and DNumber objects."""
327351 if isinstance (x , DNumber ):
328- new_value = math . log10 (x .value )
352+ new_value = log10 (x .value )
329353 new_derivatives = {
330- var : deriv / (x .value * math . log (10 ))
354+ var : deriv / (x .value * log (10 ))
331355 for var , deriv in x .derivatives .items ()
332356 }
333357 return DNumber (new_value , new_derivatives )
334358 else :
335- return math .log10 (x )
359+ m = _get_math_module (x )
360+ return m .log10 (x )
336361
337362
338363def log2 (x ):
339364 """Base-2 logarithm function that works with both floats and DNumber objects."""
340365 if isinstance (x , DNumber ):
341- new_value = math . log2 (x .value )
366+ new_value = log2 (x .value )
342367 new_derivatives = {
343- var : deriv / (x .value * math . log (2 )) for var , deriv in x .derivatives .items ()
368+ var : deriv / (x .value * log (2 )) for var , deriv in x .derivatives .items ()
344369 }
345370 return DNumber (new_value , new_derivatives )
346371 else :
347- return math .log2 (x )
372+ m = _get_math_module (x )
373+ return m .log2 (x )
348374
349375
350376def floor (x ):
351377 """Floor function that works with both floats and DNumber objects."""
352378 if isinstance (x , DNumber ):
353- new_value = math . floor (x .value )
379+ new_value = floor (x .value )
354380 # Derivative of floor is 0 everywhere except at integer points (where it's undefined)
355381 new_derivatives = {var : 0.0 for var in x .derivatives .keys ()}
356382 return DNumber (new_value , new_derivatives )
357383 else :
358- return math .floor (x )
384+ m = _get_math_module (x )
385+ return m .floor (x )
359386
360387
361388def ceil (x ):
362389 """Ceiling function that works with both floats and DNumber objects."""
363390 if isinstance (x , DNumber ):
364- new_value = math . ceil (x .value )
391+ new_value = ceil (x .value )
365392 # Derivative of ceil is 0 everywhere except at integer points (where it's undefined)
366393 new_derivatives = {var : 0.0 for var in x .derivatives .keys ()}
367394 return DNumber (new_value , new_derivatives )
368395 else :
369- return math .ceil (x )
396+ m = _get_math_module (x )
397+ return m .ceil (x )
370398
371399
372400def pow (x , y ):
0 commit comments