Skip to content

Commit 958f98b

Browse files
Hardcode84Diptorup Deb
authored andcommitted
pairwise_distance f64 emulation
1 parent 9bb68d2 commit 958f98b

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77

88

9-
@nb.kernel
9+
@nb.kernel(gpu_fp64_truncate="auto")
1010
def _pairwise_distance_kernel(X1, X2, D):
1111
i = nb.get_global_id(0)
1212

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77

88

9-
@nb.njit(parallel=True, fastmath=True)
9+
@nb.njit(parallel=True, fastmath=True, gpu_fp64_truncate="auto")
1010
def _pairwise_distance(X1, X2, D):
1111
x1 = np.sum(np.square(X1), axis=1)
1212
x2 = np.sum(np.square(X2), axis=1)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_p.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99

10-
@nb.njit(parallel=True, fastmath=True)
10+
@nb.njit(parallel=True, fastmath=True, gpu_fp64_truncate="auto")
1111
def _pairwise_distance(X1, X2, D):
1212
"""Naïve pairwise distance impl - take an array representing M points in N
1313
dimensions, and return the M x M matrix of Euclidean distances

0 commit comments

Comments
 (0)