Skip to content

Commit ed73d4c

Browse files
committed
Add numpy type hints example
1 parent 7436c34 commit ed73d4c

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

source-code/typing/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ errors.
88
Type checking can be done using [mypy](http://mypy-lang.org/index.html).
99

1010
## What is it?
11+
1. `mypy.ini`: mypy configuration file.
1112
1. `correct.py`: code that has type annotations, and no type errors.
1213
1. `incorrect_01.py`: code that has type annotations, and passes a string
1314
to a function that expects an `int`.
@@ -33,3 +34,5 @@ Type checking can be done using [mypy](http://mypy-lang.org/index.html).
3334
1. `typed_duck_typing_false_positive.py`: example code illustrating
3435
duck typing using type hints for which mypy 0.910 generates a
3536
false positive.
37+
1. `numpy_typing.py`: illustration of a script using both numpy and
38+
matplotlib with type hints.

source-code/typing/mypy.ini

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[mypy]
2+
strict = True
3+
plugins = numpy.typing.mypy_plugin
4+
5+
[mypy-matplotlib.*]
6+
ignore_missing_imports = True
7+
ignore_errors = True

source-code/typing/numpy_typing.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import numpy.typing as npt
7+
8+
9+
def generate_input(a: float | np.float32 | np.float64,
10+
b: float | np.float32 | np.float64,
11+
n: int | np.int32 | np.int64) -> npt.NDArray[np.float64]:
12+
'''Generate x-values in given range
13+
14+
Parameters
15+
----------
16+
a: float
17+
lower bound
18+
b: float
19+
upper bound
20+
n: int
21+
number of points to generate
22+
23+
Returns
24+
-------
25+
np.NDarray
26+
generated x-values
27+
'''
28+
return np.linspace(a, b, n)
29+
30+
31+
def gaussian(x: npt.NDArray[np.float64], mu: float, sigma: float) -> npt.NDArray[np.float64]:
32+
y: npt.NDArray[np.float64] = np.exp(-0.5*(x - mu)**2/sigma)/np.sqrt(2.0*np.pi*sigma)
33+
return y
34+
35+
36+
def plot_function(x: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> None:
37+
plt.plot(x, y)
38+
plt.show()
39+
return
40+
41+
42+
if __name__ == '__main__':
43+
arg_parser = argparse.ArgumentParser(description='numpy type checking')
44+
arg_parser.add_argument('--mu', type=float, default=0.0,
45+
help='mean value')
46+
arg_parser.add_argument('--sigma', type=float, default=1.0,
47+
help='standard deviation')
48+
arg_parser.add_argument('--n', type=int, default=10,
49+
help='number of points')
50+
arg_parser.add_argument('--plot', action='set_true',
51+
help='show plot')
52+
options = arg_parser.parse_args()
53+
x = generate_input(-3.0, 3.0, options.n)
54+
y = gaussian(x, options.mu, options.sigma)
55+
if options.plot:
56+
plot_function(x, y)
57+
else:
58+
print(y)

0 commit comments

Comments
 (0)