Skip to content

Commit 915b291

Browse files
committed
Add fp32 support for numba_dpex_k.py
1 parent 97e19f5 commit 915b291

File tree

7 files changed: 31 additions and 25 deletions.

dpbench/benchmarks/black_scholes/black_scholes_numba_dpex_k.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,15 @@
44

55
from math import erf, exp, log, sqrt
66

7+
import dpnp as np
78
import numba_dpex as dpex
89

910

1011
@dpex.kernel
1112
def _black_scholes_kernel(nopt, price, strike, t, rate, volatility, call, put):
13+
dtype = price.dtype
1214
mr = -rate
13-
sig_sig_two = volatility * volatility * 2
15+
sig_sig_two = volatility * volatility * dtype.type(2)
1416

1517
i = dpex.get_global_id(0)
1618

@@ -22,14 +24,14 @@ def _black_scholes_kernel(nopt, price, strike, t, rate, volatility, call, put):
2224
b = T * mr
2325

2426
z = T * sig_sig_two
25-
c = 0.25 * z
26-
y = 1.0 / sqrt(z)
27+
c = dtype.type(0.25) * z
28+
y = dtype.type(1.0) / sqrt(z)
2729

2830
w1 = (a - b + c) * y
2931
w2 = (a - b - c) * y
3032

31-
d1 = 0.5 + 0.5 * erf(w1)
32-
d2 = 0.5 + 0.5 * erf(w2)
33+
d1 = dtype.type(0.5) + dtype.type(0.5) * erf(w1)
34+
d2 = dtype.type(0.5) + dtype.type(0.5) * erf(w2)
3335

3436
Se = exp(b) * S
3537

dpbench/benchmarks/dbscan/dbscan_numba_dpex_k.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def get_neighborhood(n, dim, data, eps, ind_lst, sz_lst, block_size, nblocks):
6363
i2 = n if ii + 1 == nblocks1 else i1 + block_size1
6464
for j in range(start, stop):
6565
for k in range(i1, i2):
66-
dist = 0.0
66+
dist = data.dtype.type(0.0)
6767
for m in range(dim):
6868
diff = data[k * dim + m] - data[j * dim + m]
6969
dist += diff * diff

dpbench/benchmarks/gpairs/gpairs_numba_dpex_k.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import math
6-
75
import numba_dpex as dpex
8-
import numpy as np
96

107
# This implementation is numba dpex kernel version with atomics.
118

@@ -27,6 +24,7 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
2724
rbins_squared,
2825
result,
2926
):
27+
dtype = x0.dtype
3028
lid0 = dpex.get_local_id(0)
3129
gr0 = dpex.get_group_id(0)
3230

@@ -38,9 +36,9 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
3836

3937
n_wi = 20
4038

41-
dsq_mat = dpex.private.array(shape=(20 * 20), dtype=np.float32)
42-
w0_vec = dpex.private.array(shape=(20), dtype=np.float32)
43-
w1_vec = dpex.private.array(shape=(20), dtype=np.float32)
39+
dsq_mat = dpex.private.array(shape=(20 * 20), dtype=dtype)
40+
w0_vec = dpex.private.array(shape=(20), dtype=dtype)
41+
w1_vec = dpex.private.array(shape=(20), dtype=dtype)
4442

4543
offset0 = gr0 * n_wi * lws0 + lid0
4644
offset1 = gr1 * n_wi * lws1 + lid1
@@ -80,7 +78,7 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
8078

8179
# update slm_hist. Use work-item private buffer of 16 tfloat elements
8280
for k in range(0, slm_hist_size, private_hist_size):
83-
private_hist = dpex.private.array(shape=(16), dtype=np.float32)
81+
private_hist = dpex.private.array(shape=(16), dtype=dtype)
8482
for p in range(private_hist_size):
8583
private_hist[p] = 0.0
8684

@@ -95,7 +93,9 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
9593
pk = k
9694
for p in range(private_hist_size):
9795
private_hist[p] += (
98-
pw if (pk < nbins and dsq <= rbins_squared[pk]) else 0.0
96+
pw
97+
if (pk < nbins and dsq <= rbins_squared[pk])
98+
else dtype.type(0.0)
9999
)
100100
pk += 1
101101

dpbench/benchmarks/kmeans/kmeans_numba_dpex_k.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@
1111
def groupByCluster(arrayP, arrayPcluster, arrayC, num_points, num_centroids):
1212
idx = dpex.get_global_id(0)
1313
# if idx < num_points: # why it was removed??
14-
minor_distance = -1
14+
dtype = arrayC.dtype
15+
minor_distance = dtype.type(-1)
1516
for i in range(num_centroids):
1617
dx = arrayP[idx, 0] - arrayC[i, 0]
1718
dy = arrayP[idx, 1] - arrayC[i, 1]
@@ -41,8 +42,9 @@ def calCentroidsSum2(arrayP, arrayPcluster, arrayCsum, arrayCnumpoint):
4142
@dpex.kernel
4243
def updateCentroids(arrayC, arrayCsum, arrayCnumpoint, num_centroids):
4344
i = dpex.get_global_id(0)
44-
arrayC[i, 0] = arrayCsum[i, 0] / arrayCnumpoint[i]
45-
arrayC[i, 1] = arrayCsum[i, 1] / arrayCnumpoint[i]
45+
dtype = arrayC.dtype
46+
arrayC[i, 0] = arrayCsum[i, 0] / dtype.type(arrayCnumpoint[i])
47+
arrayC[i, 1] = arrayCsum[i, 1] / dtype.type(arrayCnumpoint[i])
4648

4749

4850
@dpex.kernel

dpbench/benchmarks/knn/knn_numba_dpex_k.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,15 +20,16 @@ def _knn_kernel( # noqa: C901: TODO: can we simplify logic?
2020
votes_to_classes_lst,
2121
data_dim,
2222
):
23+
dtype = train.dtype
2324
i = dpex.get_global_id(0)
2425
# here k has to be 5 in order to match with numpy
25-
queue_neighbors = dpex.private.array(shape=(5, 2), dtype=np.float64)
26+
queue_neighbors = dpex.private.array(shape=(5, 2), dtype=dtype)
2627

2728
for j in range(k):
2829
x1 = train[j]
2930
x2 = test[i]
3031

31-
distance = 0.0
32+
distance = dtype.type(0.0)
3233
for jj in range(data_dim):
3334
diff = x1[jj] - x2[jj]
3435
distance += diff * diff
@@ -55,7 +56,7 @@ def _knn_kernel( # noqa: C901: TODO: can we simplify logic?
5556
x1 = train[j]
5657
x2 = test[i]
5758

58-
distance = 0.0
59+
distance = dtype.type(0.0)
5960
for jj in range(data_dim):
6061
diff = x1[jj] - x2[jj]
6162
distance += diff * diff
@@ -83,7 +84,7 @@ def _knn_kernel( # noqa: C901: TODO: can we simplify logic?
8384
votes_to_classes[int(queue_neighbors[j, 1])] += 1
8485

8586
max_ind = 0
86-
max_value = 0
87+
max_value = dtype.type(0)
8788

8889
for j in range(classes_num):
8990
if votes_to_classes[j] > max_value:

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ def _pairwise_distance_kernel(X1, X2, D):
1313
X2_rows = X2.shape[0]
1414
X1_cols = X1.shape[1]
1515
for j in range(X2_rows):
16-
d = 0.0
16+
d = X1.dtype.type(0.0)
1717
for k in range(X1_cols):
1818
tmp = X1[i, k] - X2[j, k]
1919
d += tmp * tmp

dpbench/benchmarks/rambo/rambo_numba_dpex_k.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,12 @@
99

1010
@dpex.kernel
1111
def _rambo(C1, F1, Q1, nout, output):
12+
dtype = C1.dtype
1213
i = dpex.get_global_id(0)
1314
for j in range(nout):
14-
C = 2.0 * C1[i, j] - 1.0
15-
S = sqrt(1 - C * C)
16-
F = 2.0 * pi * F1[i, j]
15+
C = dtype.type(2.0) * C1[i, j] - dtype.type(1.0)
16+
S = sqrt(dtype.type(1) - C * C)
17+
F = dtype.type(2.0 * pi) * F1[i, j]
1718
Q = -log(Q1[i, j])
1819

1920
output[i, j, 0] = Q

0 commit comments

Comments
 (0)