Skip to content

Commit 7456b8e

Browse files
committed
Add single support for numba_dpex_p
1 parent 71d3c6b commit 7456b8e

File tree

6 files changed

+57
-15
lines changed

6 files changed

+57
-15
lines changed

dpbench/benchmarks/black_scholes/black_scholes_numba_dpex_p.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,15 @@
1010

1111
@dpjit
1212
def black_scholes(nopt, price, strike, t, rate, volatility, call, put):
13+
dtype = price.dtype
1314
mr = -rate
14-
sig_sig_two = volatility * volatility * 2
15+
sig_sig_two = volatility * volatility * dtype.type(2)
16+
17+
# TODO: get rid of it once prange supports dtype
18+
# https://github.com/IntelPython/numba-dpex/issues/1063
19+
float025 = dtype.type(0.25)
20+
float1 = dtype.type(1.0)
21+
float05 = dtype.type(0.5)
1522

1623
for i in prange(nopt):
1724
P = price[i]
@@ -22,14 +29,14 @@ def black_scholes(nopt, price, strike, t, rate, volatility, call, put):
2229
b = T * mr
2330

2431
z = T * sig_sig_two
25-
c = 0.25 * z
26-
y = 1.0 / sqrt(z)
32+
c = float025 * z
33+
y = float1 / sqrt(z)
2734

2835
w1 = (a - b + c) * y
2936
w2 = (a - b - c) * y
3037

31-
d1 = 0.5 + 0.5 * erf(w1)
32-
d2 = 0.5 + 0.5 * erf(w2)
38+
d1 = float05 + float05 * erf(w1)
39+
d2 = float05 + float05 * erf(w2)
3340

3441
Se = exp(b) * S
3542

dpbench/benchmarks/dbscan/dbscan_numba_dpex_p.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def _queue_empty(head, tail):
5353
def get_neighborhood(n, dim, data, eps, ind_lst, sz_lst):
5454
block_size = 1
5555
nblocks = n // block_size + int(n % block_size > 0)
56+
57+
# TODO: get rid of it once prange supports dtype
58+
# https://github.com/IntelPython/numba-dpex/issues/1063
59+
float0 = data.dtype.type(0.0)
60+
5661
for i in nb.prange(nblocks):
5762
start = i * block_size
5863
stop = n if i + 1 == nblocks else start + block_size
@@ -64,7 +69,7 @@ def get_neighborhood(n, dim, data, eps, ind_lst, sz_lst):
6469
i2 = n if ii + 1 == nblocks1 else i1 + block_size1
6570
for j in range(start, stop):
6671
for k in range(i1, i2):
67-
dist = 0.0
72+
dist = float0
6873
for m in range(dim):
6974
diff = data[k * dim + m] - data[j * dim + m]
7075
dist += diff * diff

dpbench/benchmarks/kmeans/kmeans_numba_dpex_p.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
# determine the euclidean distance from the cluster center to each point
1111
@dpex.dpjit
1212
def groupByCluster(arrayP, arrayPcluster, arrayC, num_points, num_centroids):
13+
# TODO: get rid of it once prange supports dtype
14+
# https://github.com/IntelPython/numba-dpex/issues/1063
15+
float1 = arrayC.dtype.type(-1)
16+
1317
# parallel for loop
1418
for i0 in nb.prange(num_points):
15-
minor_distance = -1
19+
minor_distance = float1
1620
for i1 in range(num_centroids):
1721
dx = arrayP[i0, 0] - arrayC[i1, 0]
1822
dy = arrayP[i0, 1] - arrayC[i1, 1]
@@ -52,6 +56,13 @@ def updateCentroids(arrayC, arrayCsum, arrayCnumpoint, num_centroids):
5256
arrayC[i, 1] = arrayCsum[i, 1] / arrayCnumpoint[i]
5357

5458

59+
@dpex.dpjit
60+
def updateCentroids32(arrayC, arrayCsum, arrayCnumpoint, num_centroids):
61+
for i in nb.prange(num_centroids):
62+
arrayC[i, 0] = arrayCsum[i, 0] / nb.float32(arrayCnumpoint[i])
63+
arrayC[i, 1] = arrayCsum[i, 1] / nb.float32(arrayCnumpoint[i])
64+
65+
5566
@dpex.dpjit
5667
def copy_arrayC(arrayC, arrayP, num_centroids):
5768
for i in nb.prange(num_centroids):
@@ -85,7 +96,12 @@ def kmeans_numba(
8596
arrayP, arrayPcluster, arrayCsum, arrayCnumpoint
8697
)
8798

88-
updateCentroids(arrayC, arrayCsum, arrayCnumpoint, num_centroids)
99+
# TODO: get rid of it once prange supports dtype
100+
# https://github.com/IntelPython/numba-dpex/issues/1063
101+
if isinstance(arrayC.dtype.type(0), np.float32):
102+
updateCentroids32(arrayC, arrayCsum, arrayCnumpoint, num_centroids)
103+
else:
104+
updateCentroids(arrayC, arrayCsum, arrayCnumpoint, num_centroids)
89105

90106
return arrayC, arrayCsum, arrayCnumpoint
91107

dpbench/benchmarks/knn/knn_numba_dpex_p.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@ def knn( # noqa: C901: TODO: can we simplify logic?
2020
votes_to_classes,
2121
data_dim,
2222
):
23+
# TODO: get rid of it once prange supports dtype
24+
# https://github.com/IntelPython/numba-dpex/issues/1063
25+
float0 = x_train.dtype.type(0.0)
26+
2327
for i in nb.prange(test_size):
2428
queue_neighbors = np.empty(shape=(k, 2))
2529

2630
for j in range(k):
2731
x1 = x_train[j]
2832
x2 = x_test[i]
2933

30-
distance = 0.0
34+
distance = float0
3135
for jj in range(data_dim):
3236
diff = x1[jj] - x2[jj]
3337
distance += diff * diff
@@ -54,7 +58,7 @@ def knn( # noqa: C901: TODO: can we simplify logic?
5458
x1 = x_train[j]
5559
x2 = x_test[i]
5660

57-
distance = 0.0
61+
distance = float0
5862
for jj in range(data_dim):
5963
diff = x1[jj] - x2[jj]
6064
distance += diff * diff
@@ -84,7 +88,7 @@ def knn( # noqa: C901: TODO: can we simplify logic?
8488
v_to_c_i[int(queue_neighbors[j, 1])] += 1
8589

8690
max_ind = 0
87-
max_value = 0
91+
max_value = float0
8892

8993
for j in range(classes_num):
9094
if v_to_c_i[j] > max_value:

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_p.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ def pairwise_distance(X1, X2, D):
2222
X2_rows = X2.shape[0]
2323
X1_cols = X1.shape[1]
2424

25+
# TODO: get rid of it once prange supports dtype
26+
# https://github.com/IntelPython/numba-dpex/issues/1063
27+
float0 = X1.dtype.type(0.0)
28+
2529
# Outermost parallel loop over the matrix X1
2630
for i in nb.prange(X1_rows):
2731
# Loop over the matrix X2
2832
for j in range(X2_rows):
29-
d = 0.0
33+
d = float0
3034
# Compute exclidean distance
3135
for k in range(X1_cols):
3236
tmp = X1[i, k] - X2[j, k]

dpbench/benchmarks/rambo/rambo_numba_dpex_p.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99

1010
@dpjit
1111
def rambo(nevts, nout, C1, F1, Q1, output):
12+
# TODO: get rid of it once prange supports dtype
13+
# https://github.com/IntelPython/numba-dpex/issues/1063
14+
float1 = C1.dtype.type(1.0)
15+
float2 = C1.dtype.type(2.0)
16+
floatPi = C1.dtype.type(np.pi)
17+
1218
for i in nb.prange(nevts):
1319
for j in range(nout):
14-
C = 2.0 * C1[i, j] - 1.0
15-
S = np.sqrt(1 - np.square(C))
16-
F = 2.0 * np.pi * F1[i, j]
20+
C = float2 * C1[i, j] - float1
21+
S = np.sqrt(float1 - np.square(C))
22+
F = float2 * floatPi * F1[i, j]
1723
Q = -np.log(Q1[i, j])
1824

1925
output[i, j, 0] = Q

0 commit comments

Comments
 (0)