|
| 1 | +// Copyright (c) 2020 Jisang Yoon |
| 2 | +// All rights reserved. |
| 3 | +// |
| 4 | +// This source code is licensed under the Apache 2.0 license found in the |
| 5 | +// LICENSE file in the root directory of this source tree. |
| 6 | +#pragma once |
| 7 | +#include <cuda_fp16.h> |
| 8 | + |
| 9 | +// experimental codes to use half precision |
| 10 | +// not properly working yet.. |
| 11 | +// #define HALF_PRECISION 1 |
| 12 | + |
| 13 | +// #if __CUDA_ARCH__ < 530 |
| 14 | +// #undef HALF_PRECISION |
| 15 | +// #endif |
| 16 | + |
| 17 | +#ifdef HALF_PRECISION |
| 18 | + typedef half cuda_scalar; |
| 19 | + #define mul(x, y) ( __hmul(x, y) ) |
| 20 | + #define add(x, y) ( __hadd(x, y) ) |
| 21 | + #define sub(x, y) ( __hsub(x, y) ) |
| 22 | + #define gt(x, y) ( __hgt(x, y) ) // x > y |
| 23 | + #define ge(x, y) ( __hge(x, y) ) // x >= y |
| 24 | + #define lt(x, y) ( __hlt(x, y) ) // x < y |
| 25 | + #define le(x, y) ( __hle(x, y) ) // x <= y |
| 26 | + #define out_scalar(x) ( __half2float(x) ) |
| 27 | + #define conversion(x) ( __float2half(x) ) |
| 28 | +#else |
| 29 | + typedef float cuda_scalar; |
| 30 | + #define mul(x, y) ( x * y ) |
| 31 | + #define add(x, y) ( x + y ) |
| 32 | + #define sub(x, y) ( x - y ) |
| 33 | + #define gt(x, y) ( x > y ) |
| 34 | + #define ge(x, y) ( x >= y ) |
| 35 | + #define lt(x, y) ( x < y ) |
| 36 | + #define le(x, y) ( x <= y ) |
| 37 | + #define out_scalar(x) ( x ) |
| 38 | + #define conversion(x) ( x ) |
| 39 | +#endif |
| 40 | + |
| 41 | +#define WARP_SIZE 32 |
| 42 | + |
| 43 | +struct Neighbor { |
| 44 | + cuda_scalar distance; |
| 45 | + int nodeid; |
| 46 | + bool checked; |
| 47 | +}; |
| 48 | + |
| 49 | +// to manage the compatibility with hnswlib |
| 50 | +typedef unsigned int tableint; |
| 51 | +typedef unsigned int sizeint; |
| 52 | +typedef float scalar; |
| 53 | +typedef size_t labeltype; |
| 54 | + |
| 55 | +enum DIST_TYPE { |
| 56 | + DOT, |
| 57 | + L2, |
| 58 | +}; |
0 commit comments