|
| 1 | +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import dpctl |
| 6 | +import dpnp |
| 7 | +import numba as nb |
| 8 | +import pytest |
| 9 | + |
| 10 | +from numba_dpex import dpjit |
| 11 | + |
| 12 | + |
| 13 | +def test_pairwise_distance(): |
| 14 | + @dpjit |
| 15 | + def pairwise_distance(X1, X2, D): |
| 16 | + """Naïve pairwise distance impl - take an array representing M points in N |
| 17 | + dimensions, and return the M x M matrix of Euclidean distances |
| 18 | +
|
| 19 | + Args: |
| 20 | + X1 : Set of points |
| 21 | + X2 : Set of points |
| 22 | + D : Outputted distance matrix |
| 23 | + """ |
| 24 | + # Size of inputs |
| 25 | + X1_rows = X1.shape[0] |
| 26 | + X2_rows = X2.shape[0] |
| 27 | + X1_cols = X1.shape[1] |
| 28 | + |
| 29 | + # TODO: get rid of it once prange supports dtype |
| 30 | + # https://github.com/IntelPython/numba-dpex/issues/1063 |
| 31 | + float0 = X1.dtype.type(0.0) |
| 32 | + |
| 33 | + # Outermost parallel loop over the matrix X1 |
| 34 | + for i in nb.prange(X1_rows): |
| 35 | + # Loop over the matrix X2 |
| 36 | + for j in range(X2_rows): |
| 37 | + d = float0 |
| 38 | + # Compute exclidean distance |
| 39 | + for k in range(X1_cols): |
| 40 | + tmp = X1[i, k] - X2[j, k] |
| 41 | + d += tmp * tmp |
| 42 | + # Write computed distance to distance matrix |
| 43 | + D[i, j] = dpnp.sqrt(d) |
| 44 | + |
| 45 | + q = dpctl.SyclQueue() |
| 46 | + X1 = dpnp.ones((100, 2), sycl_queue=q) |
| 47 | + X2 = dpnp.ones((100, 2), sycl_queue=q) |
| 48 | + D = dpnp.empty((100, 100), sycl_queue=q) |
| 49 | + |
| 50 | + try: |
| 51 | + pairwise_distance(X1, X2, D) |
| 52 | + except: |
| 53 | + pytest.fail("Failed to compile prange loop for pairwise distance calc") |
0 commit comments