|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 |
|
6 | | -def initialize(npoints, niters, seed, ndims, ncentroids): |
| 6 | +def initialize(npoints, niters, seed, ndims, ncentroids, types_dict): |
7 | 7 | import numpy as np |
8 | 8 | import numpy.random as default_rng |
9 | 9 |
|
10 | | - dtype = np.float64 |
| 10 | + f_dtype = types_dict["float"] |
| 11 | + i_dtype = types_dict["int"] |
11 | 12 | XL = 1.0 |
12 | 13 | XH = 5.0 |
13 | 14 |
|
14 | 15 | default_rng.seed(seed) |
15 | 16 |
|
16 | | - arrayP = default_rng.uniform(XL, XH, (npoints, ndims)).astype(dtype) |
17 | | - arrayPclusters = np.ones(npoints, dtype=np.int64) |
18 | | - arrayC = np.ones((ncentroids, ndims), dtype=dtype) |
19 | | - arrayCsum = np.ones((ncentroids, ndims), dtype=dtype) |
20 | | - arrayCnumpoint = np.ones(ncentroids, dtype=np.int64) |
| 17 | + arrayP = default_rng.uniform(XL, XH, (npoints, ndims)).astype(f_dtype) |
| 18 | + arrayPclusters = np.ones(npoints, dtype=i_dtype) |
| 19 | + arrayC = np.ones((ncentroids, ndims), dtype=f_dtype) |
| 20 | + arrayCsum = np.ones((ncentroids, ndims), dtype=f_dtype) |
| 21 | + arrayCnumpoint = np.ones(ncentroids, dtype=i_dtype) |
21 | 22 |
|
22 | 23 | return (arrayP, arrayPclusters, arrayC, arrayCsum, arrayCnumpoint) |
0 commit comments