Skip to content

Commit 8905f50

Browse files
npolina4ZzEeKkAa
authored andcommitted
Added dtype parameter in dpnp.sum() function call for avoid unnecessary copying.
1 parent 7f1a3c9 commit 8905f50

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

dpbench/benchmarks/gpairs/gpairs_dpnp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ def _gpairs_impl(x1, y1, z1, w1, x2, y2, z2, w2, rbins):
1212
+ np.square(z2 - z1[:, None])
1313
)
1414
return np.array(
15-
[np.outer(w1, w2)[dm <= rbins[k]].sum() for k in range(len(rbins))],
15+
[
16+
np.outer(w1, w2)[dm <= rbins[k]].sum(dtype=np.result_type(w1, w2))
17+
for k in range(len(rbins))
18+
],
1619
device=x1.device,
1720
)
1821

dpbench/benchmarks/l2_norm/l2_norm_dpnp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77

88
def l2_norm(a, d):
99
sq = np.square(a)
10-
sum = sq.sum(axis=1)
10+
sum = sq.sum(axis=1, dtype=sq.dtype)
1111
d[:] = np.sqrt(sum)

dpbench/benchmarks/pairwise_distance/pairwise_distance_dpnp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77

88
def pairwise_distance(X1, X2, D):
9-
x1 = np.sum(np.square(X1), axis=1)
10-
x2 = np.sum(np.square(X2), axis=1)
9+
x1 = np.sum(np.square(X1), axis=1, dtype=X1.dtype)
10+
x2 = np.sum(np.square(X2), axis=1, dtype=X2.dtype)
1111
np.dot(X1, X2.T, D)
1212
D *= -2
1313
x3 = x1.reshape(x1.size, 1)

0 commit comments

Comments
 (0)