diff --git a/.github/workflows/oneapi_githubactions_build.yml b/.github/workflows/oneapi_githubactions_build.yml
new file mode 100644
index 0000000000..abbb84d839
--- /dev/null
+++ b/.github/workflows/oneapi_githubactions_build.yml
@@ -0,0 +1,82 @@
+name: oneapi_ghactions_buildrun
+
+on:
+  push:
+    branches: [ "feature/sycl" ]
+
+defaults:
+  run:
+    shell: bash
+
+env:
+  BUILD_TYPE: RELEASE
+
+jobs:
+  buildrun:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Install software
+        run: |
+          sudo apt update
+          sudo apt install -y gpg-agent wget
+          # download the key to system keyring
+          wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | sudo tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null
+          # add signed entry to apt sources and configure the APT client to use Intel repository:
+          echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
+          sudo apt update
+          sudo apt install intel-oneapi-hpc-toolkit
+
+      - name: Setup oneAPI
+        run: |
+          source /opt/intel/oneapi/setvars.sh
+          printenv >> $GITHUB_ENV
+          which icpx
+          icpx -v
+          cat /proc/cpuinfo
+
+      - uses: actions/checkout@v4
+
+      - name: Ccache for gh actions
+        uses: hendrikmuhs/ccache-action@v1.2.16
+        with:
+          key: ${{ github.job }}
+          max-size: 2000M
+
+      - name: Configure CMake
+        run: >
+          cmake
+          -B ${{github.workspace}}/build
+          -GNinja
+          -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}}
+          -DCMAKE_C_COMPILER=icx
+          -DCMAKE_CXX_COMPILER=icpx
+          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
+          -DQUDA_TARGET_TYPE=SYCL
+          -DQUDA_SYCL_TARGETS=spir64_x86_64
+          -DCMAKE_CXX_FLAGS="-Wno-unsupported-floating-point-opt"
+          -DCMAKE_SYCL_FLAGS="-Xs -march=avx512 -Wno-unsupported-floating-point-opt"
+          -DSYCL_LINK_FLAGS="-Xs -march=avx512 -fsycl-device-code-split=per_kernel -fsycl-max-parallel-link-jobs=4 -flink-huge-device-code"
+          -DQUDA_DIRAC_COVDEV=OFF
+          -DQUDA_DIRAC_DISTANCE_PRECONDITIONING=OFF
+          -DQUDA_MULTIGRID=ON
+          -DQUDA_INTERFACE_QDPJIT=ON
+          -DQUDA_FAST_COMPILE_REDUCE=ON
+          -DQUDA_FAST_COMPILE_DSLASH=ON
+          -DQUDA_OPENMP=OFF
+          -DQUDA_MPI=ON
+          -DQUDA_PRECISION=12
+          -DQUDA_DIRAC_DEFAULT_OFF=ON
+          -DQUDA_DIRAC_STAGGERED=ON
+          -DQUDA_DIRAC_WILSON=ON
+
+      - name: Build
+        run: cmake --build ${{github.workspace}}/build
+
+      - name: Install
+        run: cmake --install ${{github.workspace}}/build
+
+      - name: Run
+        run: |
+          cd ${{github.workspace}}/build
+          #ctest
+          ctest -E 'invert_test_asqtad_single|invert_test_splitgrid_asqtad_single|unitarize_link_single'
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7e93a258de..5c4abcc034 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -214,7 +214,11 @@ if(QUDA_MAX_MULTI_RHS_TILE GREATER QUDA_MAX_MULTI_RHS)
   message(SEND_ERROR "QUDA_MAX_MULTI_RHS_TILE is greater than QUDA_MAX_MULTI_RHS")
 endif()
 
-set(QUDA_MAX_KERNEL_ARG_SIZE "4096" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+if(${QUDA_TARGET_TYPE} STREQUAL "SYCL")
+  set(QUDA_MAX_KERNEL_ARG_SIZE "2048" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+else()
+  set(QUDA_MAX_KERNEL_ARG_SIZE "4096" CACHE STRING "maximum static size of the kernel arguments in bytes passed to a kernel on the target architecture")
+endif()
 if(QUDA_MAX_KERNEL_ARG_SIZE GREATER 32764)
   message(SEND_ERROR "Maximum QUDA_MAX_KERNEL_ARG_SIZE is 32764")
 endif()
diff --git a/cmake/CMakeDetermineSYCLCompiler.cmake b/cmake/CMakeDetermineSYCLCompiler.cmake
new file mode 100644
index 0000000000..144b288e92
--- /dev/null
+++ b/cmake/CMakeDetermineSYCLCompiler.cmake
@@ -0,0 +1,36 @@
+if(NOT CMAKE_SYCL_COMPILER)
+  set(CMAKE_SYCL_COMPILER ${CMAKE_CXX_COMPILER})
+endif()
+mark_as_advanced(CMAKE_SYCL_COMPILER)
+message(STATUS "The SYCL compiler is " ${CMAKE_SYCL_COMPILER})
+
+if(NOT CMAKE_SYCL_COMPILER_ID_RUN)
+  set(CMAKE_SYCL_COMPILER_ID_RUN 1)
+
+  # Try to identify the compiler.
+  set(CMAKE_SYCL_COMPILER_ID)
+  set(CMAKE_SYCL_PLATFORM_ID)
+  file(READ ${CMAKE_ROOT}/Modules/CMakePlatformId.h.in CMAKE_SYCL_COMPILER_ID_PLATFORM_CONTENT)
+
+  set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS_FIRST)
+  set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS)
+
+  set(CMAKE_CXX_COMPILER_ID_CONTENT "#if defined(__INTEL_LLVM_COMPILER)\n# define COMPILER_ID \"IntelLLVM\"\n")
+  string(APPEND CMAKE_CXX_COMPILER_ID_CONTENT "#elif defined(__clang__)\n# define COMPILER_ID \"Clang\"\n")
+  string(APPEND CMAKE_CXX_COMPILER_ID_CONTENT "#endif\n")
+  include(${CMAKE_ROOT}/Modules/CMakeDetermineCompilerId.cmake)
+  CMAKE_DETERMINE_COMPILER_ID(SYCL SYCLFLAGS CMakeCXXCompilerId.cpp)
+
+  _cmake_find_compiler_sysroot(SYCL)
+endif()
+
+
+#set(CMAKE_SYCL_COMPILER_ID_TEST_FLAGS_FIRST)
+#set(CMAKESYCL_COMPILER_ID_TEST_FLAGS "-c")
+#include(${CMAKE_ROOT}/Modules/CMakeDetermineCompilerId.cmake)
+#CMAKE_DETERMINE_COMPILER_ID(SYCL SYCLFLAGS CMakeCXXCompilerId.cpp)
+
+configure_file(${CMAKE_CURRENT_LIST_DIR}/CMakeSYCLCompiler.cmake.in
+  ${CMAKE_PLATFORM_INFO_DIR}/CMakeSYCLCompiler.cmake)
+
+set(CMAKE_SYCL_COMPILER_ENV_VAR "SYCL")
diff --git a/cmake/CMakeSYCLCompiler.cmake.in b/cmake/CMakeSYCLCompiler.cmake.in
new file mode 100644
index 0000000000..2dc0b7acd2
--- /dev/null
+++ b/cmake/CMakeSYCLCompiler.cmake.in
@@ -0,0 +1,3 @@
+set(CMAKE_SYCL_COMPILER "@CMAKE_SYCL_COMPILER@")
+set(CMAKE_SYCL_COMPILER_LOADED 1)
+set(CMAKE_SYCL_COMPILER_ENV_VAR "SYCL")
diff --git a/cmake/CMakeSYCLInformation.cmake b/cmake/CMakeSYCLInformation.cmake
new file mode 100644
index 0000000000..6572616fbf
--- /dev/null
+++ b/cmake/CMakeSYCLInformation.cmake
@@ -0,0 +1,47 @@
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_PIC)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_PIC ${CMAKE_CXX_COMPILE_OPTIONS_PIC})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_PIE)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_PIE ${CMAKE_CXX_COMPILE_OPTIONS_PIE})
+endif()
+if(NOT CMAKE_SYCL_LINK_OPTIONS_PIE)
+  set(CMAKE_SYCL_LINK_OPTIONS_PIE ${CMAKE_CXX_LINK_OPTIONS_PIE})
+endif()
+if(NOT CMAKE_SYCL_LINK_OPTIONS_NO_PIE)
+  set(CMAKE_SYCL_LINK_OPTIONS_NO_PIE ${CMAKE_CXX_LINK_OPTIONS_NO_PIE})
+endif()
+
+if(NOT CMAKE_SYCL_OUTPUT_EXTENSION)
+  set(CMAKE_SYCL_OUTPUT_EXTENSION ${CMAKE_CXX_OUTPUT_EXTENSION})
+endif()
+
+if(NOT CMAKE_INCLUDE_FLAG_SYCL)
+  set(CMAKE_INCLUDE_FLAG_SYCL ${CMAKE_INCLUDE_FLAG_CXX})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OPTIONS_EXPLICIT_LANGUAGE)
+  set(CMAKE_SYCL_COMPILE_OPTIONS_EXPLICIT_LANGUAGE ${CMAKE_CXX_COMPILE_OPTIONS_EXPLICIT_LANGUAGE})
+endif()
+
+if(NOT CMAKE_SYCL_DEPENDS_USE_COMPILER)
+  set(CMAKE_SYCL_DEPENDS_USE_COMPILER ${CMAKE_CXX_DEPENDS_USE_COMPILER})
+endif()
+
+if(NOT CMAKE_DEPFILE_FLAGS_SYCL)
+  set(CMAKE_DEPFILE_FLAGS_SYCL ${CMAKE_DEPFILE_FLAGS_CXX})
+endif()
+
+if(NOT CMAKE_SYCL_DEPFILE_FORMAT)
+  set(CMAKE_SYCL_DEPFILE_FORMAT ${CMAKE_CXX_DEPFILE_FORMAT})
+endif()
+
+if(NOT CMAKE_SYCL_COMPILE_OBJECT)
+  set(CMAKE_SYCL_COMPILE_OBJECT " -o -c ")
+endif()
+
+if(NOT CMAKE_SYCL_LINK_EXECUTABLE)
+  set(CMAKE_SYCL_LINK_EXECUTABLE " -o ")
+endif()
+
+set(CMAKE_SYCL_INFORMATION_LOADED 1)
diff --git a/cmake/CMakeTestSYCLCompiler.cmake b/cmake/CMakeTestSYCLCompiler.cmake
new file mode 100644
index 0000000000..e7c7219631
--- /dev/null
+++ b/cmake/CMakeTestSYCLCompiler.cmake
@@ -0,0 +1 @@
+set(CMAKE_SYCL_COMPILER_WORKS 1 CACHE INTERNAL "")
diff --git a/include/array.h b/include/array.h
index 3005087c85..e5e65b1493 100644
--- a/include/array.h
+++ b/include/array.h
@@ -34,6 +34,8 @@ namespace quda
     return output;
   }
 
+  template constexpr T &elem(array &a, int i) { return a[i]; }
+
   /**
    * @brief Element-wise maximum of two arrays
    * @param a first array
diff --git a/include/blas_helper.cuh b/include/blas_helper.cuh
index 806eef5f5e..fa92ec024b 100644
--- a/include/blas_helper.cuh
+++ b/include/blas_helper.cuh
@@ -193,10 +193,10 @@ namespace quda
       norm_t max_[n];
       // two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued)
 #pragma unroll
-      for (int i = 0; i < n; i++) max_[i] = fmaxf(fabsf((norm_t)v[i].real()), fabsf((norm_t)v[i].imag()));
+      for (int i = 0; i < n; i++) max_[i] = quda::max(quda::abs((norm_t)v[i].real()), quda::abs((norm_t)v[i].imag()));
       norm_t scale = 0.0;
 #pragma unroll
-      for (int i = 0; i < n; i++) scale = fmaxf(max_[i], scale);
+      for (int i = 0; i < n; i++) scale = quda::max(max_[i], scale);
       norm = scale * fixedInvMaxValue::value;
       return fdividef(fixedMaxValue::value, scale);
     }
@@ -309,7 +309,7 @@ namespace quda
       memcpy(&vecTmp[6], &norm, sizeof(norm_t)); // pack the norm
       array vecTmp2;
       copy_and_scale(vecTmp2, &v_[0], scale_inv);
-      std::memcpy(&vecTmp, &vecTmp2, sizeof(vecTmp2));
+      memcpy(&vecTmp, &vecTmp2, sizeof(vecTmp2));
       // second do vectorized copy into memory
       vector_store(data.spinor, parity * cb_offset + x, vecTmp);
     }
diff --git a/include/clover_field_order.h b/include/clover_field_order.h
index dc9a92084a..9d004922f6 100644
--- a/include/clover_field_order.h
+++ b/include/clover_field_order.h
@@ -865,8 +865,8 @@ namespace quda {
       if (clover.Order() != QUDA_QDPJIT_CLOVER_ORDER) { errorQuda("Invalid clover order %d for this accessor", clover.Order()); }
-      offdiag = clover_ ? ((Float **)clover_)[0] : clover.data(inverse)[0];
-      diag = clover_ ? ((Float **)clover_)[1] : clover.data(inverse)[1];
+      offdiag = clover_ ? reinterpret_cast(clover_)[0] : clover.data(inverse)[0];
+      diag = clover_ ?
reinterpret_cast(clover_)[1] : clover.data(inverse)[1]; } QudaTwistFlavorType TwistFlavor() const { return twist_flavor; } diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 60521dba72..c9cd1ba496 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -494,7 +494,7 @@ namespace quda template auto data() const { if (ghost_only) errorQuda("Not defined for ghost-only field"); - return reinterpret_cast(v.data()); + return static_cast(v.data()); } /** @@ -641,6 +641,7 @@ namespace quda @param[in] gdr_recv Whether we are using GDR on the receive side */ int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const; + void commsQuery(int n, int d[], bool done[], bool gdr_send, bool gdr_recv) const; /** @brief Wait on halo communication to complete @@ -878,7 +879,6 @@ namespace quda /** * @brief Print the site vector - * @param[in] a The field we are printing from * @param[in] parity Parity index * @param[in] x_cb Checkerboard space-time index * @param[in] rank The rank we are requesting from (default is rank = 0) diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 2c46c23ea9..b762d56ff2 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -241,8 +241,9 @@ namespace quda constexpr int M = nSpinBlock * nColor * nVec; #pragma unroll for (int i = 0; i < M; i++) { - vec_t tmp - = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); + // vec_t tmp + // = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); + vec_t tmp = vector_load(in + parity * offset_cb, x_cb * N + chi * M + i); memcpy(&out[i], &tmp, sizeof(vec_t)); } } @@ -1061,7 +1062,7 @@ namespace quda for (int i = 0; i < length_ghost / 2; i++) max_[i] = fmaxf((norm_type)fabsf((norm_type)v[i]), (norm_type)fabsf((norm_type)v[i + length_ghost / 2])); #pragma unroll - for (int i = 0; i < length_ghost / 2; i++) scale = fmaxf(max_[i], scale); + for (int i = 0; i < length_ghost / 2; i++) scale = max(max_[i], scale); ghost_norm[2 * dim + dir][parity * faceVolumeCB[dim] + x] = scale * fixedInvMaxValue::value; scale_inv = fdividef(fixedMaxValue::value, scale); } @@ -1203,7 +1204,7 @@ namespace quda for (int i = 0; i < length / 2; i++) max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2])); #pragma unroll - for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale); + for (int i = 0; i < length / 2; i++) scale = max(max_[i], scale); norm[x + parity * norm_offset] = scale * fixedInvMaxValue::value; scale_inv = fdividef(fixedMaxValue::value, scale); } @@ -1306,10 +1307,10 @@ namespace quda // two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued) #pragma unroll for (int i = 0; i < length_ghost / 2; i++) - max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length_ghost / 2])); + max_[i] = max(abs((norm_type)v[i]), abs((norm_type)v[i + length_ghost / 2])); norm_type scale = 0.0; #pragma unroll - for (int i = 0; i < length_ghost / 2; i++) scale = fmaxf(max_[i], scale); + for (int i = 0; i < length_ghost / 2; i++) scale = max(max_[i], scale); norm_type nrm = scale * fixedInvMaxValue::value; real scale_inv = fdividef(fixedMaxValue::value, scale); @@ -1411,11 +1412,10 @@ namespace quda norm_type max_[length / 2]; // two-pass to increase ILP (assumes length divisible by two, e.g. 
complex-valued) #pragma unroll - for (int i = 0; i < length / 2; i++) - max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2])); + for (int i = 0; i < length / 2; i++) max_[i] = max(abs((norm_type)v[i]), abs((norm_type)v[i + length / 2])); norm_type scale = 0.0; #pragma unroll - for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale); + for (int i = 0; i < length / 2; i++) scale = max(max_[i], scale); norm_type nrm = scale * fixedInvMaxValue::value; real scale_inv = fdividef(fixedMaxValue::value, scale); diff --git a/include/comm_quda.h b/include/comm_quda.h index d33e949bdf..0930f7bd79 100644 --- a/include/comm_quda.h +++ b/include/comm_quda.h @@ -422,6 +422,7 @@ namespace quda void comm_start(MsgHandle *mh); void comm_wait(MsgHandle *mh); int comm_query(MsgHandle *mh); + // void comm_query(int n, MsgHandle *mh[], int *outcount, int array_of_indices[]); template void comm_allreduce_sum(T &v); template void comm_allreduce_max(T &v); diff --git a/include/communicator_quda.h b/include/communicator_quda.h index c5ddbe1f75..338ab150ce 100644 --- a/include/communicator_quda.h +++ b/include/communicator_quda.h @@ -771,6 +771,8 @@ namespace quda int comm_query(MsgHandle *mh); + // void comm_query(int n, MsgHandle *mh[], int *outcount, int array_of_indices[]); + template T deterministic_reduce(T *array, int n) { std::sort(array, array + n); // sort reduction into ascending order for deterministic reduction diff --git a/include/convert.h b/include/convert.h index f56751873c..2d1026cb31 100644 --- a/include/convert.h +++ b/include/convert.h @@ -128,6 +128,7 @@ namespace quda } }; +#if 0 /** @brief Fast float-to-integer round used on the device */ @@ -148,6 +149,7 @@ namespace quda return i; } }; +#endif /** @brief Regular double-to-integer round used on the host @@ -156,6 +158,7 @@ namespace quda constexpr int operator()(double d) { return static_cast(rint(d)); } }; +#if 0 /** @brief Fast double-to-integer round used on the device */ @@ -166,6 +169,7 @@ namespace quda return reinterpret_cast(d); } }; +#endif /** @brief Copy function which is trival between floating point diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index da002550ff..884fa31dd9 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -503,6 +503,12 @@ namespace quda dslash.template operator()(x_cb, s, parity); } + template + __forceinline__ __device__ void apply_dslash(D &dslash, int x_cb, int s, int parity, bool alive) + { + dslash.template operator()(x_cb, s, parity, alive); + } + #ifdef NVSHMEM_COMMS /** * @brief helper function for nvshmem uber kernel to signal that the interior kernel has completed. 
diff --git a/include/externals/json.hpp b/include/externals/json.hpp index cb27e05811..443aa9a665 100644 --- a/include/externals/json.hpp +++ b/include/externals/json.hpp @@ -21895,7 +21895,7 @@ inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC /// @brief user-defined string literal for JSON values /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json operator "" _json(const char* s, std::size_t n) +inline nlohmann::json operator""_json(const char* s, std::size_t n) { return nlohmann::json::parse(s, s + n); } @@ -21903,7 +21903,7 @@ inline nlohmann::json operator "" _json(const char* s, std::size_t n) /// @brief user-defined string literal for JSON pointer /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json_pointer/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +inline nlohmann::json::json_pointer operator""_json_pointer(const char* s, std::size_t n) { return nlohmann::json::json_pointer(std::string(s, n)); } diff --git a/include/gauge_field_order.h b/include/gauge_field_order.h index 827dde5bbf..c531c91d18 100644 --- a/include/gauge_field_order.h +++ b/include/gauge_field_order.h @@ -1945,6 +1945,7 @@ namespace quda { LegacyOrder(u, ghost_), volumeCB(u.VolumeCB()) { for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? ((Float **)gauge_)[i] : u.data(i); + // for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? gauge_[i] : u.data(i); } __device__ __host__ inline void load(complex v[length / 2], int x, int dir, int parity, real = 1.0) const @@ -1991,6 +1992,7 @@ namespace quda { LegacyOrder(u, ghost_), volumeCB(u.VolumeCB()) { for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? ((Float **)gauge_)[i] : u.data(i); + // for (int i = 0; i < 4; i++) gauge[i] = gauge_ ? gauge_[i] : u.data(i); } __device__ __host__ inline void load(complex v[length / 2], int x, int dir, int parity, real = 1.0) const diff --git a/include/kernels/clover_outer_product.cuh b/include/kernels/clover_outer_product.cuh index e887e65f0d..65953e4008 100644 --- a/include/kernels/clover_outer_product.cuh +++ b/include/kernels/clover_outer_product.cuh @@ -40,7 +40,7 @@ namespace quda { const ColorSpinorField &p_halo, cvector_ref &x, const ColorSpinorField &x_halo, const std::vector &coeff) : kernel_param(dim3(dim == -1 ? static_cast(x_halo.getDslashConstant().volume_4d_cb) : - x_halo.getDslashConstant().ghostFaceCB[dim], + x_halo.getDslashConstant().ghostFaceCB[dim == -1 ? 0 : dim], x.SiteSubset(), dim == -1 ? 4 : dim)), n_src(p.size()), force(force), diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index 7a59bdfae5..f91f5873d6 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -1382,7 +1382,7 @@ namespace quda { }; template struct storeCoarseSharedAtomic_impl { - template void operator()(Args...) + template void operator()(Args...) 
{ errorQuda("Shared-memory atomic aggregation not supported on host"); } @@ -1402,9 +1402,9 @@ namespace quda { template using Cache = SharedMemoryCache, DimsStaticConditional<2, 1, 1>>; template using Ops = KernelOps>; - template + template inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, - int parity, const Pack &pack, const Ftor &ftor) + int parity, const Pack &pack, const Ftor &ftor, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -1468,57 +1468,61 @@ namespace quda { if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) { + if (!allthreads || active) { #pragma unroll - for (int i = 0; i < TileType::M; i++) { + for (int i = 0; i < TileType::M; i++) { #pragma unroll - for (int j = 0; j < TileType::N; j++) { - if (pack.dir == QUDA_IN_PLACE) { - // same as dir == QUDA_FORWARDS - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } else { - arg.Y_atomic.atomicAdd(dim_index,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - Y[i_block0+i][j_block0+j][x_][s_row][s_col]); - - if (pack.dir == QUDA_BACKWARDS) { - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_col,s_row,j0+j,i0+i, - conj(X[i_block0+i][j_block0+j][x_][s_row][s_col])); + for (int j = 0; j < TileType::N; j++) { + if (pack.dir == QUDA_IN_PLACE) { + // same as dir == QUDA_FORWARDS + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); } else { - arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } - - if (!arg.bidirectional) { - if (Arg::fineSpin != 1 && s_row == s_col) arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - X[i_block0+i][j_block0+j][x_][s_row][s_col]); - else arg.X_atomic.atomicAdd(0,coarse_parity,coarse_x_cb,s_row,s_col,i0+i,j0+j, - -X[i_block0+i][j_block0+j][x_][s_row][s_col]); - } - } // dir == QUDA_IN_PLACE + arg.Y_atomic.atomicAdd(dim_index, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + Y[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + + if (pack.dir == QUDA_BACKWARDS) { + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_col, s_row, j0 + j, i0 + i, + conj(X[i_block0 + i][j_block0 + j][x_][s_row][s_col])); + } else { + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + } + + if (!arg.bidirectional) { + if (Arg::fineSpin != 1 && s_row == s_col) + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + else + arg.X_atomic.atomicAdd(0, coarse_parity, coarse_x_cb, s_row, s_col, i0 + i, j0 + j, + -X[i_block0 + i][j_block0 + j][x_][s_row][s_col]); + } + } // dir == QUDA_IN_PLACE + } } } } } }; - template + template __device__ __host__ void storeCoarseSharedAtomic(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, - int i0, int j0, int parity, const Ftor &ftor) + int i0, int j0, int parity, const Ftor &ftor, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; switch (arg.dir) { case QUDA_BACKWARDS: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; case QUDA_FORWARDS: - 
target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; case QUDA_IN_PLACE: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, - Pack(), ftor); + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, + parity, Pack(), ftor, active); break; default: break;// do nothing @@ -1605,9 +1609,9 @@ namespace quda { } - template + template __device__ __host__ void computeVUV(const Ftor &ftor, int parity, int x_cb, int i0, int j0, int parity_coarse_, - int coarse_x_cb_) + int coarse_x_cb_, bool active) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -1634,7 +1638,7 @@ namespace quda { using Ctype = decltype(make_tile_C, false>(arg.vuvTile)); Ctype vuv[Arg::coarseSpin * Arg::coarseSpin]; - multiplyVUV(vuv, arg, parity, x_cb, i0, j0); + if (!allthreads || active) multiplyVUV(vuv, arg, parity, x_cb, i0, j0); if (isDiagonal && !isFromCoarseClover) { #pragma unroll @@ -1642,8 +1646,8 @@ namespace quda { } if (arg.shared_atomic) - storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, ftor); - else + storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, ftor, active); + else if (!allthreads || active) storeCoarseGlobalAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, arg); } @@ -1721,17 +1725,24 @@ namespace quda { @param[in] parity_c_row parity * output color row @param[in] c_col output coarse color column */ - __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col) + template + __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col, bool active = true) { - int parity, parity_coarse, x_coarse_cb, c_row; - target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); - - if (parity > 1) return; - if (c_row >= arg.vuvTile.M_tiles) return; - if (c_col >= arg.vuvTile.N_tiles) return; - if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - - computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + int parity = 0, parity_coarse = 0, x_coarse_cb = 0, c_row = 0; + if (!allthreads || active) + target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); + + // if (parity > 1) return; + // if (c_row >= arg.vuvTile.M_tiles) return; + // if (c_col >= arg.vuvTile.N_tiles) return; + // if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; + if (parity > 1) active = false; + if (c_row >= arg.vuvTile.M_tiles) active = false; + if (c_col >= arg.vuvTile.N_tiles) active = false; + if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) active = false; + + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, + x_coarse_cb, active); } }; @@ -1751,17 +1762,24 @@ namespace quda { @param[in] parity_c_row parity * output color row @param[in] c_col output coarse color column */ - __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col) + template + __device__ __host__ inline void operator()(int x_cb, int parity_c_row, int c_col, bool active = true) { - int parity, parity_coarse, x_coarse_cb, c_row; - target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); - - if (parity > 1) return; - if (c_row >= arg.vuvTile.M_tiles) return; - if (c_col >= arg.vuvTile.N_tiles) return; - if 
(!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - - computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + int parity = 0, parity_coarse = 0, x_coarse_cb = 0, c_row = 0; + if (!allthreads || active) + target::dispatch(parity_coarse, x_coarse_cb, parity, x_cb, parity_c_row, c_row, c_col, arg); + + // if (parity > 1) return; + // if (c_row >= arg.vuvTile.M_tiles) return; + // if (c_col >= arg.vuvTile.N_tiles) return; + // if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; + if (parity > 1) active = false; + if (c_row >= arg.vuvTile.M_tiles) active = false; + if (c_col >= arg.vuvTile.N_tiles) active = false; + if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) active = false; + + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, + x_coarse_cb, active); } }; diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index 94b6bce0d9..54a137893c 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -203,7 +203,8 @@ namespace quda { } static constexpr const char* filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb, int src_flavor, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_flavor, int parity, bool active = true) { using namespace linalg; // for Cholesky const int clover_parity = arg.nParity == 2 ? parity : arg.parity; @@ -214,15 +215,21 @@ namespace quda { const int flavor = src_flavor % 2; int my_flavor_idx = x_cb + flavor * arg.volumeCB; - fermion in = arg.in[src_idx](my_flavor_idx, spinor_parity); - in.toRel(); // change to chiral basis here - + fermion in; int chirality = flavor; // relabel flavor as chirality + Mat A; + if (!allthreads || active) { + in = arg.in[src_idx](my_flavor_idx, spinor_parity); + in.toRel(); // change to chiral basis here + A = arg.clover(x_cb, clover_parity, chirality); + } else { + in = fermion {}; + A = Mat {}; + } + // (C + i mu gamma_5 tau_3 - epsilon tau_1 ) [note: appropriate signs carried in arg.a / arg.b] const complex a(0.0, chirality == 0 ? 
arg.a : -arg.a); - Mat A = arg.clover(x_cb, clover_parity, chirality); - SharedMemoryCache cache {*this}; half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion @@ -251,27 +258,32 @@ namespace quda { out_chi[flavor] += arg.b * in_chi[1 - flavor]; } - if (arg.inverse) { - if (arg.dynamic_clover) { - Mat A2 = A.square(); - A2 += arg.a2_minus_b2; - Cholesky, N> cholesky(A2); + if (!allthreads || active) { + if (arg.inverse) { + if (arg.dynamic_clover) { + Mat A2 = A.square(); + A2 += arg.a2_minus_b2; + Cholesky, N> cholesky(A2); #pragma unroll - for (int flavor = 0; flavor < n_flavor; flavor++) - out_chi[flavor] = static_cast(0.25) * cholesky.backward(cholesky.forward(out_chi[flavor])); - } else { - Mat Ainv = arg.cloverInv(x_cb, clover_parity, chirality); + for (int flavor = 0; flavor < n_flavor; flavor++) + out_chi[flavor] = static_cast(0.25) * cholesky.backward(cholesky.forward(out_chi[flavor])); + } else { + Mat Ainv = arg.cloverInv(x_cb, clover_parity, chirality); #pragma unroll - for (int flavor = 0; flavor < n_flavor; flavor++) - out_chi[flavor] = static_cast(2.0) * (Ainv * out_chi[flavor]); + for (int flavor = 0; flavor < n_flavor; flavor++) + out_chi[flavor] = static_cast(2.0) * (Ainv * out_chi[flavor]); + } } } swizzle(out_chi, chirality); // undo the flavor-chirality swizzle - fermion out = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); - out.toNonRel(); // change basis back - arg.out[src_idx](my_flavor_idx, spinor_parity) = out; + if (!allthreads || active) { + fermion out = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); + out.toNonRel(); // change basis back + + arg.out[src_idx](my_flavor_idx, spinor_parity) = out; + } } }; } diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 086a7def06..f630e66c0d 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -339,7 +339,8 @@ namespace quda { } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb_color_offset, int src_parity, int sMd) + template + __device__ __host__ inline void operator()(int x_cb_color_offset, int src_parity, int sMd, bool active = true) { int x_cb = x_cb_color_offset; int color_offset = 0; @@ -368,11 +369,16 @@ namespace quda { typename CoarseDslashParams::array_t out {}; if (Arg::dslash) { - applyDslash(out, dim, dir, x_cb, src_idx, parity, s, color_block, color_offset, arg); + if (!allthreads || active) { + applyDslash(out, dim, dir, x_cb, src_idx, parity, s, color_block, color_offset, arg); + } target::dispatch(out, dir, dim, *this); } - if (doBulk() && Arg::clover && dir==0 && dim==0) applyClover(out, arg, x_cb, src_idx, parity, s, color_block, color_offset); + if (!allthreads || active) { + if (doBulk() && Arg::clover && dir == 0 && dim == 0) + applyClover(out, arg, x_cb, src_idx, parity, s, color_block, color_offset); + } if (dir==0 && dim==0) { const int my_spinor_parity = (arg.nParity == 2) ? 
parity : 0; @@ -380,13 +386,17 @@ namespace quda { // reduce down to the first group of column-split threads out = warp_combine(out); + if (!allthreads || active) { #pragma unroll - for (int color_local=0; color_local()) arg.out[src_idx](my_spinor_parity, x_cb, s, c) = out[color_local]; - else arg.out[src_idx](my_spinor_parity, x_cb, s, c) += out[color_local]; + for (int color_local = 0; color_local < Mc; color_local++) { + int c = color_block + color_local; // global color index + if (color_offset == 0) { + // if not halo we just store, else we accumulate + if (doBulk()) + arg.out[src_idx](my_spinor_parity, x_cb, s, c) = out[color_local]; + else + arg.out[src_idx](my_spinor_parity, x_cb, s, c) += out[color_local]; + } } } } diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index 46e0ae876a..439cce5433 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -73,8 +73,8 @@ namespace quda template constexpr domainWall4DFusedM5(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_s, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_s, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -82,73 +82,74 @@ namespace quda int src_idx = src_s / arg.Ls; int s = src_s % arg.Ls; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, s, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; Vector stencil_out; - applyWilson(stencil_out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + applyWilson(stencil_out, arg, coord, parity, idx, thread_dim, active, src_idx); + } Vector out; - constexpr bool shared = true; // Use shared memory - // In the following `x_cb` are all passed as `x_cb = 0`, since it will not be used if `shared = true`, and `shared = true` - if (active) { - - /****** - * Apply M5pre - */ - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE) { - constexpr bool sync = false; - out = d5(*this, stencil_out, - my_spinor_parity, 0, s, src_idx); - } + if (allthreads||active) { + /****** + * Apply M5pre + */ + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE) { + constexpr bool sync = false; + out = d5 + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive&&active); + } } int xs = coord.x_cb + s * arg.dc.volume_4d_cb; if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS_M5_INV_DAG) { - /****** - * Apply the two M5inv's: - * this is actually y = 1 * x - kappa_b^2 * m5inv * D4 * in - * out = m5inv-dagger * y - */ - if (active) { - constexpr bool sync = false; - out = variableInv( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); + /****** + * Apply the two M5inv's: + * this is actually y = 1 * x - kappa_b^2 * m5inv * D4 * in + * out = m5inv-dagger * y + */ + if (allthreads||active) { + constexpr bool sync = false; + out = variableInv + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive&&active); } - Vector aggregate_external; - if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); - out = x + arg.a_5[s] * out; - } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector y = arg.y[src_idx](xs, my_spinor_parity); - aggregate_external = xpay ? arg.a_5[s] * out : out; - out = y + aggregate_external; - } + if (!allthreads||alive) { + Vector aggregate_external; + if (xpay && mykernel_type == INTERIOR_KERNEL) { + Vector x = arg.x[src_idx](xs, my_spinor_parity); + out = x + arg.a_5[s] * out; + } else if (mykernel_type != INTERIOR_KERNEL && active) { + Vector y = arg.y[src_idx](xs, my_spinor_parity); + aggregate_external = xpay ? 
arg.a_5[s] * out : out; + out = y + aggregate_external; + } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.y[src_idx](xs, my_spinor_parity) = out; + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.y[src_idx](xs, my_spinor_parity) = out; - if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + aggregate_external; + if (mykernel_type != INTERIOR_KERNEL && active) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + aggregate_external; + } } bool complete = isComplete(arg, coord); - if (complete && active) { - constexpr bool sync = true; - constexpr bool this_dagger = true; - // Then we apply the second m5inv-dag - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); - } + if (allthreads || (complete && active)) { + constexpr bool sync = true; + constexpr bool this_dagger = true; + // Then we apply the second m5inv-dag + auto tmp = variableInv + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; + } } else if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS || Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { @@ -159,25 +160,28 @@ namespace quda * or out = m5mob * x - kappa_b^2 * m5pre *D4 * in (Dslash5Type::DSLASH5_PRE_MOBIUS_M5_MOBIUS) */ - if (active) { - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS) { out = stencil_out; } + if (allthreads || active) { + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS) { out = stencil_out; } - if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { - constexpr bool sync = false; - out = d5( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); - } - } + if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { + constexpr bool sync = false; + out = d5(*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive && active); + } + } if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); + Vector x; + if (!allthreads || alive) x = arg.x[src_idx](xs, my_spinor_parity); constexpr bool sync_m5mob = Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS ? false : true; - x = d5( - *this, x, my_spinor_parity, 0, s, src_idx); - out = x + arg.a_5[s] * out; + x = d5(*this, x, my_spinor_parity, 0, s, src_idx, alive); + if (!allthreads || alive) out = x + arg.a_5[s] * out; } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + (xpay ? arg.a_5[s] * out : out); + if (!allthreads || alive) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + (xpay ? arg.a_5[s] * out : out); + } } } else { @@ -191,20 +195,22 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS) { // Apply the m5inv. constexpr bool sync = false; - out = variableInv( - *this, stencil_out, my_spinor_parity, 0, s, src_idx); + out = variableInv + (*this, stencil_out, my_spinor_parity, 0, s, src_idx, alive); } - if (xpay && mykernel_type == INTERIOR_KERNEL) { - Vector x = arg.x[src_idx](xs, my_spinor_parity); - out = x + arg.a_5[s] * out; - } else if (mykernel_type != INTERIOR_KERNEL && active) { - Vector x = arg.out[src_idx](xs, my_spinor_parity); - out = x + (xpay ? 
arg.a_5[s] * out : out); - } + if (!allthreads || alive) { + if (xpay && mykernel_type == INTERIOR_KERNEL) { + Vector x = arg.x[src_idx](xs, my_spinor_parity); + out = x + arg.a_5[s] * out; + } else if (mykernel_type != INTERIOR_KERNEL && active) { + Vector x = arg.out[src_idx](xs, my_spinor_parity); + out = x + (xpay ? arg.a_5[s] * out : out); + } + } bool complete = isComplete(arg, coord); - if (complete && active) { + if (allthreads || (complete && active)) { /****** * First apply M5inv, and then M5pre @@ -212,12 +218,13 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS_M5_PRE) { // Apply the m5inv. constexpr bool sync_m5inv = false; - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); + auto tmp = variableInv + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); // Apply the m5pre. constexpr bool sync_m5pre = true; - out = d5(*this, out, my_spinor_parity, - 0, s, src_idx); + tmp = d5 + (*this, tmp, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; } /****** @@ -226,16 +233,17 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { // Apply the m5pre. constexpr bool sync_m5pre = false; - out = d5(*this, out, my_spinor_parity, - 0, s, src_idx); + auto tmp = d5 + (*this, out, my_spinor_parity, 0, s, src_idx, alive && complete && active); // Apply the m5inv. constexpr bool sync_m5inv = true; - out = variableInv( - *this, out, my_spinor_parity, 0, s, src_idx); + tmp = variableInv + (*this, tmp, my_spinor_parity, 0, s, src_idx, alive && complete && active); + if (alive && complete && active) out = tmp; } } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](xs, my_spinor_parity) = out; + if (alive && (mykernel_type != EXTERIOR_KERNEL_ALL || active)) arg.out[src_idx](xs, my_spinor_parity) = out; } }; diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index 9ea0419b8a..25433b46e4 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -215,9 +215,9 @@ namespace quda using Ops = std::conditional_t, NoKernelOps>; }; - template - __device__ __host__ inline Vector d5(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s, int src_idx) + template + __device__ __host__ inline Vector d5(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s, int src_idx, bool alive) { const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; @@ -240,19 +240,21 @@ namespace quda cache.save(in.project(4, proj_dir)); cache.sync(); } - const int fwd_s = (s + 1) % arg.Ls; - const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; - HalfVector half_in; - if constexpr (shared) { - half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity); - } else { - Vector full_in = arg.in[src_idx](fwd_idx, parity); - half_in = full_in.project(4, proj_dir); - } - if (s == arg.Ls - 1) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); - } else { - out += half_in.reconstruct(4, proj_dir); + if (!allthreads || alive) { + const int fwd_s = (s + 1) % arg.Ls; + const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; + HalfVector half_in; + if constexpr (shared) { + half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity); + } else { + Vector full_in = arg.in[src_idx](fwd_idx, parity); + half_in = full_in.project(4, proj_dir); + } + if (s == arg.Ls - 1) { + out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + } 
else { + out += half_in.reconstruct(4, proj_dir); + } } } @@ -263,20 +265,22 @@ namespace quda cache.save(in.project(4, proj_dir)); cache.sync(); } - const int back_s = (s + arg.Ls - 1) % arg.Ls; - const int back_idx = back_s * arg.volume_4d_cb + x_cb; - HalfVector half_in; - if constexpr (shared) { - half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity); - } else { - Vector full_in = arg.in[src_idx](back_idx, parity); - half_in = full_in.project(4, proj_dir); - } - if (s == 0) { - out += (-arg.m_f * half_in).reconstruct(4, proj_dir); - } else { - out += half_in.reconstruct(4, proj_dir); - } + if (!allthreads || alive) { + const int back_s = (s + arg.Ls - 1) % arg.Ls; + const int back_idx = back_s * arg.volume_4d_cb + x_cb; + HalfVector half_in; + if constexpr (shared) { + half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity); + } else { + Vector full_in = arg.in[src_idx](back_idx, parity); + half_in = full_in.project(4, proj_dir); + } + if (s == 0) { + out += (-arg.m_f * half_in).reconstruct(4, proj_dir); + } else { + out += half_in.reconstruct(4, proj_dir); + } + } } } else { // use_half_vector @@ -291,40 +295,44 @@ namespace quda cache.sync(); } - { // forwards direction - const int fwd_s = (s + 1) % arg.Ls; - const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; - const Vector in - = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity) : arg.in[src_idx](fwd_idx, parity); - constexpr int proj_dir = dagger ? +1 : -1; - if (s == arg.Ls - 1) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); - } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + if (!allthreads || alive) { + { // forwards direction + const int fwd_s = (s + 1) % arg.Ls; + const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; + const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity) : + arg.in[src_idx](fwd_idx, parity); + constexpr int proj_dir = dagger ? +1 : -1; + if (s == arg.Ls - 1) { + out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + } else { + out += in.project(4, proj_dir).reconstruct(4, proj_dir); + } } - } - { // backwards direction - const int back_s = (s + arg.Ls - 1) % arg.Ls; - const int back_idx = back_s * arg.volume_4d_cb + x_cb; - const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity) : - arg.in[src_idx](back_idx, parity); - constexpr int proj_dir = dagger ? -1 : +1; - if (s == 0) { - out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); - } else { - out += in.project(4, proj_dir).reconstruct(4, proj_dir); + { // backwards direction + const int back_s = (s + arg.Ls - 1) % arg.Ls; + const int back_idx = back_s * arg.volume_4d_cb + x_cb; + const Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity) : + arg.in[src_idx](back_idx, parity); + constexpr int proj_dir = dagger ? -1 : +1; + if (s == 0) { + out += (-arg.m_f * in.project(4, proj_dir)).reconstruct(4, proj_dir); + } else { + out += in.project(4, proj_dir).reconstruct(4, proj_dir); + } } } } // use_half_vector - if (type == Dslash5Type::DSLASH5_MOBIUS_PRE || type == Dslash5Type::M5_INV_MOBIUS_M5_PRE - || type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { - Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.alpha(s) * out + coeff.beta(s) * diagonal; - } else if (type == Dslash5Type::DSLASH5_MOBIUS) { - Vector diagonal = shared ? 
in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.kappa(s) * out + diagonal; + if (!allthreads || alive) { + if (type == Dslash5Type::DSLASH5_MOBIUS_PRE || type == Dslash5Type::M5_INV_MOBIUS_M5_PRE + || type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { + Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.alpha(s) * out + coeff.beta(s) * diagonal; + } else if (type == Dslash5Type::DSLASH5_MOBIUS) { + Vector diagonal = shared ? in : arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.kappa(s) * out + diagonal; + } } return out; @@ -346,7 +354,8 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; coeff_type::value, Arg> coeff(arg); @@ -358,22 +367,24 @@ namespace quda constexpr bool sync = false; constexpr bool shared = false; - Vector out = d5(*this, Vector(), parity, x_cb, s, src_idx); - - if (Arg::xpay) { - if (Arg::type == Dslash5Type::DSLASH5_DWF) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + arg.a * out; - } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS_PRE) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + coeff.a(s) * out; - } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = coeff.a(s) * x + out; + Vector out = d5(*this, Vector(), parity, x_cb, s, src_idx, alive); + + if (!allthreads || alive) { + if (Arg::xpay) { + if (Arg::type == Dslash5Type::DSLASH5_DWF) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + arg.a * out; + } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS_PRE) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + coeff.a(s) * out; + } else if (Arg::type == Dslash5Type::DSLASH5_MOBIUS) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = coeff.a(s) * x + out; + } } - } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + } } }; @@ -398,9 +409,9 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template + template __device__ __host__ inline Vector constantInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, - int src_idx) + int src_idx, bool alive) { using Arg = typename Ftor::Arg; const Arg &arg = ftor.arg; @@ -421,23 +432,25 @@ namespace quda Vector out; - for (int s = 0; s < arg.Ls; s++) { + if (!allthreads || alive) { + for (int s = 0; s < arg.Ls; s++) { - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - { - int exp = s_ < s ? arg.Ls - s + s_ : s_ - s; - real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast(1.0)); - constexpr int proj_dir = dagger ? -1 : +1; - out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); - } + { + int exp = s_ < s ? arg.Ls - s + s_ : s_ - s; + real factorR = inv * fpow(k, exp) * (s_ < s ? -arg.m_f : static_cast(1.0)); + constexpr int proj_dir = dagger ? 
-1 : +1; + out += factorR * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + } - { - int exp = s_ > s ? arg.Ls - s_ + s : s - s_; - real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast(1.0)); - constexpr int proj_dir = dagger ? +1 : -1; - out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + { + int exp = s_ > s ? arg.Ls - s_ + s : s - s_; + real factorL = inv * fpow(k, exp) * (s_ > s ? -arg.m_f : static_cast(1.0)); + constexpr int proj_dir = dagger ? +1 : -1; + out += factorL * (in.project(4, proj_dir)).reconstruct(4, proj_dir); + } } } @@ -467,9 +480,9 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template + template __device__ __host__ inline Vector variableInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, - int src_idx) + int src_idx, bool alive) { const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; @@ -486,30 +499,32 @@ namespace quda { // first do R constexpr int proj_dir = dagger ? -1 : +1; - if (shared) { - if (sync) { cache.sync(); } + if constexpr (shared) { + if constexpr (sync) { cache.sync(); } cache.save(in.project(4, proj_dir)); cache.sync(); } - int s = s_; - auto R = coeff.inv(); - HalfVector r; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorR = (s_ < s ? -arg.m_f * R : R); - - if (shared) { - r += factorR * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); - } else { - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); - } - - R *= coeff.kappa(s); - s = (s + arg.Ls - 1) % arg.Ls; - } - - out += r.reconstruct(4, proj_dir); + if (!allthreads || alive) { + int s = s_; + auto R = coeff.inv(); + HalfVector r; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorR = (s_ < s ? -arg.m_f * R : R); + + if (shared) { + r += factorR * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); + } else { + Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + r += factorR * in.project(4, proj_dir); + } + + R *= coeff.kappa(s); + s = (s + arg.Ls - 1) % arg.Ls; + } + + out += r.reconstruct(4, proj_dir); + } } { // second do L @@ -520,24 +535,26 @@ namespace quda cache.sync(); } - int s = s_; - auto L = coeff.inv(); - HalfVector l; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorL = (s_ > s ? -arg.m_f * L : L); - - if (shared) { - l += factorL * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); - } else { - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); - } - - L *= coeff.kappa(s); - s = (s + 1) % arg.Ls; - } - - out += l.reconstruct(4, proj_dir); + if (!allthreads || alive) { + int s = s_; + auto L = coeff.inv(); + HalfVector l; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorL = (s_ > s ? -arg.m_f * L : L); + + if (shared) { + l += factorL * cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity); + } else { + Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + l += factorL * in.project(4, proj_dir); + } + + L *= coeff.kappa(s); + s = (s + 1) % arg.Ls; + } + + out += l.reconstruct(4, proj_dir); + } } } else { // use_half_vector using Cache = std::conditional_t, const Ftor &>; @@ -548,44 +565,46 @@ namespace quda cache.sync(); } - { // first do R - constexpr int proj_dir = dagger ? -1 : +1; + if (!allthreads || alive) { + { // first do R + constexpr int proj_dir = dagger ? 
-1 : +1; + + int s = s_; + auto R = coeff.inv(); + HalfVector r; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorR = (s_ < s ? -arg.m_f * R : R); - int s = s_; - auto R = coeff.inv(); - HalfVector r; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorR = (s_ < s ? -arg.m_f * R : R); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + r += factorR * in.project(4, proj_dir); - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - r += factorR * in.project(4, proj_dir); + R *= coeff.kappa(s); + s = (s + arg.Ls - 1) % arg.Ls; + } - R *= coeff.kappa(s); - s = (s + arg.Ls - 1) % arg.Ls; + out += r.reconstruct(4, proj_dir); } - out += r.reconstruct(4, proj_dir); - } + { // second do L + constexpr int proj_dir = dagger ? +1 : -1; - { // second do L - constexpr int proj_dir = dagger ? +1 : -1; + int s = s_; + auto L = coeff.inv(); + HalfVector l; + for (int s_count = 0; s_count < arg.Ls; s_count++) { + auto factorL = (s_ > s ? -arg.m_f * L : L); - int s = s_; - auto L = coeff.inv(); - HalfVector l; - for (int s_count = 0; s_count < arg.Ls; s_count++) { - auto factorL = (s_ > s ? -arg.m_f * L : L); + Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : + arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); + l += factorL * in.project(4, proj_dir); - Vector in = shared ? cache.load(threadIdx.x, local_src_idx * arg.Ls + s, parity) : - arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - l += factorL * in.project(4, proj_dir); + L *= coeff.kappa(s); + s = (s + 1) % arg.Ls; + } - L *= coeff.kappa(s); - s = (s + 1) % arg.Ls; + out += l.reconstruct(4, proj_dir); } - - out += l.reconstruct(4, proj_dir); } } // use_half_vector @@ -618,7 +637,8 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { constexpr int nSpin = 4; using real = typename Arg::real; @@ -628,21 +648,25 @@ namespace quda int src_idx = src_s / arg.Ls; int s = src_s % arg.Ls; - Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); - Vector out; + Vector in, out; + if (!allthreads || alive) { in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); } constexpr bool sync = false; if constexpr (mobius_m5::var_inverse()) { // zMobius, must call variableInv - out = variableInv(*this, in, parity, x_cb, s, src_idx); + out + = variableInv(*this, in, parity, x_cb, s, src_idx, alive); } else { - out = constantInv(*this, in, parity, x_cb, s, src_idx); + out + = constantInv(*this, in, parity, x_cb, s, src_idx, alive); } - if (Arg::xpay) { - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + coeff.a(s) * out; - } + if (!allthreads || alive) { + if (Arg::xpay) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + coeff.a(s) * out; + } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + } } }; diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index 49e65da6d7..3e2bb4e647 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -110,7 +110,8 @@ 
namespace quda } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; typedef ColorSpinor Vector; @@ -121,7 +122,7 @@ namespace quda SharedMemoryCache cache {*this}; Vector out; - cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); + if (!allthreads || alive) { cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); } cache.sync(); auto Ls = arg.Ls; @@ -165,11 +166,13 @@ namespace quda } if (Arg::xpay) { // really axpy - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = arg.a * x + out; - } - } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + if (!allthreads || alive) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = arg.a * x + out; + } + } + } + if (!allthreads || alive) { arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; } } }; @@ -196,7 +199,8 @@ namespace quda } static constexpr const char *filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) + template + __device__ __host__ inline void operator()(int x_cb, int src_s, int parity, bool alive = true) { using real = typename Arg::real; typedef ColorSpinor Vector; @@ -206,7 +210,7 @@ namespace quda const auto sherman_morrison = arg.sherman_morrison; SharedMemoryCache cache {*this}; - cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); + if (!allthreads || alive) { cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); } cache.sync(); Vector out; @@ -233,10 +237,12 @@ namespace quda } } if (Arg::xpay) { // really axpy - Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); - out = x + arg.a * out; + if (!allthreads || alive) { + Vector x = arg.x[src_idx](s * arg.volume_4d_cb + x_cb, parity); + out = x + arg.a * out; + } } - arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; + if (!allthreads || alive) { arg.out[src_idx](s * arg.volume_4d_cb + x_cb, parity) = out; } } }; diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index 049f129d14..625bf68778 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -14,7 +14,7 @@ namespace quda static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2; typedef typename clover_mapper::type C; typedef typename mapper::type real; - + const C A; /** the clover field */ real a; /** this is the Wilson-dslash scale factor */ real b; /** this is the chiral twist factor */ @@ -58,8 +58,8 @@ namespace quda out(x) = M*in = a * D * in + (A(x) + i*b*gamma_5*tau_3 + c*tau_1)*x Note this routine only exists in xpay form. */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -67,9 +67,8 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? 
false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); @@ -77,53 +76,64 @@ namespace quda const int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + } if constexpr (mykernel_type == INTERIOR_KERNEL) { if (arg.dd_x.isZero(coord)) { - out = arg.a * out; + if (!allthreads || alive) out = arg.a * out; } else { - // apply the chiral and flavor twists - // use consistent load order across s to ensure better cache locality - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); SharedMemoryCache cache {*this}; - cache.save(x); + Vector tmp; + if (!allthreads || alive) { + // apply the chiral and flavor twists + // use consistent load order across s to ensure better cache locality + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + cache.save(x); - x.toRel(); // switch to chiral basis + x.toRel(); // switch to chiral basis - Vector tmp; #pragma unroll - for (int chirality = 0; chirality < 2; chirality++) { - constexpr int n = Arg::nColor * Arg::nSpin / 2; - HMatrix A = arg.A(coord.x_cb, parity, chirality); - HalfVector x_chi = x.chiral_project(chirality); - HalfVector Ax_chi = A * x_chi; - // i * mu * gamma_5 * tau_3 - const complex b(0.0, (chirality ^ flavor) == 0 ? static_cast(arg.b) : -static_cast(arg.b)); - Ax_chi += b * x_chi; - tmp += Ax_chi.chiral_reconstruct(chirality); + for (int chirality = 0; chirality < 2; chirality++) { + constexpr int n = Arg::nColor * Arg::nSpin / 2; + HMatrix A = arg.A(coord.x_cb, parity, chirality); + HalfVector x_chi = x.chiral_project(chirality); + HalfVector Ax_chi = A * x_chi; + // i * mu * gamma_5 * tau_3 + const complex b(0.0, + (chirality ^ flavor) == 0 ? 
static_cast(arg.b) : -static_cast(arg.b)); + Ax_chi += b * x_chi; + tmp += Ax_chi.chiral_reconstruct(chirality); + } + + tmp.toNonRel(); + // tmp += (c * tau_1) * x } - - tmp.toNonRel(); - // tmp += (c * tau_1) * x cache.sync(); - tmp += arg.c * cache.load_y(target::thread_idx().y + 1 - 2 * flavor); + if (!allthreads || alive) { + tmp += arg.c * cache.load_y(target::thread_idx().y + 1 - 2 * flavor); - // add the Wilson part with normalisation - out = tmp + arg.a * out; + // add the Wilson part with normalisation + out = tmp + arg.a * out; + } } } else if (active) { Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); out = x + arg.a * out; } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if (!allthreads || alive) + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; } }; diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index 4b1a470db7..fb61259d65 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -13,7 +13,7 @@ namespace quda using WilsonArg::nSpin; static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2; static constexpr bool dynamic_clover = clover::dynamic_inverse(); - + typedef typename mapper::type real; typedef typename clover_mapper::type C; const C A; @@ -64,8 +64,8 @@ namespace quda out(x) = M*in = a*(C + i*b*gamma_5*tau_3 + c*tau_1)/(C^2 + b^2 - c^2)*D*x ( xpay == false ) out(x) = M*in = in + a*(C + i*b*gamma_5*tau_3 + c*tau_1)/(C^2 + b^2 - c^2)*D*x ( xpay == true ) */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { using namespace linalg; // for Cholesky typedef typename mapper::type real; @@ -75,98 +75,107 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + if (!allthreads || alive) { + // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); - if (mykernel_type != INTERIOR_KERNEL && active) { - // if we're not the interior kernel, then we must sum the partial - Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); - out += x; + if (mykernel_type != INTERIOR_KERNEL && active) { + // if we're not the interior kernel, then we must sum the partial + Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); + out += x; + } } + constexpr int n_flavor = 2; + HalfVector out_chi[n_flavor]; // flavor array of chirally projected fermion if (isComplete(arg, coord) && active) { - out.toRel(); - - constexpr int n_flavor = 2; - HalfVector out_chi[n_flavor]; // flavor array of chirally projected fermion + out.toRel(); #pragma unroll - for (int i = 0; i < n_flavor; i++) out_chi[i] = out.chiral_project(i); - - int chirality = flavor; // relabel flavor as chirality - - SharedMemoryCache cache {*this}; - - auto swizzle = [&](HalfVector x[2], int chirality) { - if (chirality == 0) - cache.save_y(x[1], target::thread_idx().y); - else - cache.save_y(x[0], target::thread_idx().y); - cache.sync(); - if (chirality == 0) - x[1] = cache.load_y(target::thread_idx().y + 1); - else - x[0] = cache.load_y(target::thread_idx().y - 1); - }; - - swizzle(out_chi, chirality); // apply the flavor-chirality swizzle between threads - - // load in the clover matrix - HMat A = arg.A(coord.x_cb, parity, chirality); + for (int i = 0; i < n_flavor; i++) out_chi[i] = out.chiral_project(i); + } - HalfVector A_chi[n_flavor]; + int chirality = flavor; // relabel flavor as chirality + SharedMemoryCache cache {*this}; + auto swizzle = [&](HalfVector x[2], int chirality) { + if (chirality == 0) + cache.save_y(x[1], target::thread_idx().y); + else + cache.save_y(x[0], target::thread_idx().y); + cache.sync(); + if (chirality == 0) + x[1] = cache.load_y(target::thread_idx().y + 1); + else + x[0] = cache.load_y(target::thread_idx().y - 1); + }; + + swizzle(out_chi, chirality); // apply the flavor-chirality swizzle between threads + + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + // load in the clover matrix + HMat A = arg.A(coord.x_cb, parity, chirality); + + HalfVector A_chi[n_flavor]; #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - const complex b(0.0, (chirality^flavor_) == 0 ? arg.b : -arg.b); - A_chi[flavor_] = A * out_chi[flavor_]; - A_chi[flavor_] += b * out_chi[flavor_]; - A_chi[flavor_] += arg.c * out_chi[1 - flavor_]; - } - - if constexpr (Arg::dynamic_clover) { - HMat A2 = A.square(); - A2 += arg.b2_minus_c2; - Cholesky, Arg::nColor * Arg::nSpin / 2> cholesky(A2); + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + const complex b(0.0, (chirality^flavor_) == 0 ? 
arg.b : -arg.b); + A_chi[flavor_] = A * out_chi[flavor_]; + A_chi[flavor_] += b * out_chi[flavor_]; + A_chi[flavor_] += arg.c * out_chi[1 - flavor_]; + } + + if constexpr (Arg::dynamic_clover) { + HMat A2 = A.square(); + A2 += arg.b2_minus_c2; + Cholesky, Arg::nColor * Arg::nSpin / 2> cholesky(A2); #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - out_chi[flavor_] = static_cast(0.25) * cholesky.backward(cholesky.forward(A_chi[flavor_])); - } - } else { - HMat A2inv = arg.A2inv(coord.x_cb, parity, chirality); + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + out_chi[flavor_] = static_cast(0.25) * cholesky.backward(cholesky.forward(A_chi[flavor_])); + } + } else { + HMat A2inv = arg.A2inv(coord.x_cb, parity, chirality); #pragma unroll - for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { - out_chi[flavor_] = static_cast(2.0) * (A2inv * A_chi[flavor_]); - } - } - - swizzle(out_chi, chirality); // undo the flavor-chirality swizzle - Vector tmp = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); - tmp.toNonRel(); // switch back to non-chiral basis - - if (xpay && !arg.dd_x.isZero(coord)) { - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); - out = x + arg.a * tmp; - } else { - // multiplication with a needed here? - out = arg.a * tmp; - } + for (int flavor_ = 0; flavor_ < n_flavor; flavor_++) { + out_chi[flavor_] = static_cast(2.0) * (A2inv * A_chi[flavor_]); + } + } + } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + swizzle(out_chi, chirality); // undo the flavor-chirality swizzle + + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + Vector tmp = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); + tmp.toNonRel(); // switch back to non-chiral basis + + if (xpay && !arg.dd_x.isZero(coord)) { + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + out = x + arg.a * tmp; + } else { + // multiplication with a needed here? + out = arg.a * tmp; + } + } + + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + } } }; } // namespace quda diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 7effb07ae3..8244eb6787 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -64,8 +64,8 @@ namespace quda - with xpay: out(x) = M*in = x + a*(1+i*b*gamma_5 + c*tau_1)D * in */ - template - __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity) + template + __device__ __host__ __forceinline__ void operator()(int idx, int src_flavor, int parity, bool alive = true) { typedef typename mapper::type real; typedef ColorSpinor Vector; @@ -73,62 +73,68 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - bool active - = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) - int thread_dim; // which dimension is thread working on (fused kernel only) + bool active = mykernel_type != EXTERIOR_KERNEL_ALL; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); const int my_spinor_parity = arg.nParity == 2 ? 
parity : 0; int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - if (arg.dd_out.isZero(coord)) { - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; - return; + if (!allthreads || alive) { + if (arg.dd_out.isZero(coord)) { + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if constexpr (!allthreads) return; + else alive = false; + } } - if (!dagger || Arg::asymmetric) // defined in dslash_wilson.cuh - applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); - else // defined in dslash_twisted_mass_preconditioned - applyWilsonTM(out, arg, coord, parity, idx, thread_dim, active, src_idx); - - if (xpay && mykernel_type == INTERIOR_KERNEL && !arg.dd_x.isZero(coord)) { - - if (!dagger || Arg::asymmetric) { // apply inverse twist which is undone below - // use consistent load order across s to ensure better cache locality - Vector x0 = arg.x[src_idx](coord.x_cb + 0 * arg.dc.volume_4d_cb, my_spinor_parity); - Vector x1 = arg.x[src_idx](coord.x_cb + 1 * arg.dc.volume_4d_cb, my_spinor_parity); - if (flavor == 0) - out += arg.a_inv * (x0 + arg.b_inv * x0.igamma(4) + arg.c_inv * x1); - else - out += arg.a_inv * (x1 - arg.b_inv * x1.igamma(4) + arg.c_inv * x0); - } else { - Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); - out += x; // just directly add since twist already applied in the dslash + if (!allthreads || alive) { + if (!dagger || Arg::asymmetric) // defined in dslash_wilson.cuh + applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); + else // defined in dslash_twisted_mass_preconditioned + applyWilsonTM(out, arg, coord, parity, idx, thread_dim, active, src_idx); + + if (xpay && mykernel_type == INTERIOR_KERNEL && !arg.dd_x.isZero(coord)) { + if constexpr (!dagger || Arg::asymmetric) { // apply inverse twist which is undone below + // use consistent load order across s to ensure better cache locality + Vector x0 = arg.x[src_idx](coord.x_cb + 0 * arg.dc.volume_4d_cb, my_spinor_parity); + Vector x1 = arg.x[src_idx](coord.x_cb + 1 * arg.dc.volume_4d_cb, my_spinor_parity); + if (flavor == 0) + out += arg.a_inv * (x0 + arg.b_inv * x0.igamma(4) + arg.c_inv * x1); + else + out += arg.a_inv * (x1 - arg.b_inv * x1.igamma(4) + arg.c_inv * x0); + } else { + Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); + out += x; // just directly add since twist already applied in the dslash + } + } else if (mykernel_type != INTERIOR_KERNEL && active) { + // if we're not the interior kernel, then we must sum the partial + Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); + out += x; } - - } else if (mykernel_type != INTERIOR_KERNEL && active) { - // if we're not the interior kernel, then we must sum the partial - Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); - out += x; } if constexpr (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in SharedMemoryCache cache {*this}; - if (isComplete(arg, coord) && active) { - // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it - cache.save(out); - } - - cache.sync(); // safe to sync in here since other threads will exit - if (isComplete(arg, coord) && active) { - if (flavor == 0) - out = arg.a * (out + arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y + 1)); - else - out = arg.a * (out - arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y - 1)); - } + if (!allthreads 
|| alive) { + if (isComplete(arg, coord) && active) { + // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it + cache.save(out); + } + } + cache.sync(); // safe to sync here since other threads will exit if allowed, or all be here + if (!allthreads || alive) { + if (isComplete(arg, coord) && active) { + if (flavor == 0) + out = arg.a * (out + arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y + 1)); + else + out = arg.a * (out - arg.b * out.igamma(4) + arg.c * cache.load_y(target::thread_idx().y - 1)); + } + } } - if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; + if (!allthreads || alive) + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; } }; diff --git a/include/kernels/multi_blas_core.cuh b/include/kernels/multi_blas_core.cuh index 9e2c503e8b..fb444ace19 100644 --- a/include/kernels/multi_blas_core.cuh +++ b/include/kernels/multi_blas_core.cuh @@ -15,7 +15,8 @@ namespace quda #ifndef QUDA_FAST_COMPILE_REDUCE constexpr bool enable_warp_split() { return false; } #else - constexpr bool enable_warp_split() { return true; } + // constexpr bool enable_warp_split() { return true; } + constexpr bool enable_warp_split() { return false; } #endif /** diff --git a/include/register_traits.h b/include/register_traits.h index 0dd05a598c..b592f8e984 100644 --- a/include/register_traits.h +++ b/include/register_traits.h @@ -33,6 +33,20 @@ namespace quda { double4 y; }; + template std::enable_if_t, T &> constexpr elem(T &a, int i) { return (&a)[i]; } + + template ().x)> + std::enable_if_t, R &> constexpr elem(T &a, int i) + { + return (&a.x)[i]; + } + + template ().x.x), int = 0> + std::enable_if_t, R &> constexpr elem(T &a, int i) + { + return (&a.x.x)[i]; + } + /* Here we use traits to define the greater type used for mixing types of computation involving these types */ diff --git a/include/targets/generic/load_store.h b/include/targets/generic/load_store.h index 3239aeaefc..b3df298eed 100644 --- a/include/targets/generic/load_store.h +++ b/include/targets/generic/load_store.h @@ -1,17 +1,25 @@ #pragma once +#include #include namespace quda { + /** + @brief Element type used for coalesced storage. 
+ */ + template + using atom_t = std::conditional_t>; + /** @brief Non-specialized load operation */ template struct vector_load_impl { template __device__ __host__ inline void operator()(T &value, const void *ptr, int idx) { - value = reinterpret_cast(ptr)[idx]; + // value = reinterpret_cast(ptr)[idx]; + memcpy(&value, static_cast(ptr) + idx, sizeof(value)); } }; @@ -39,11 +47,12 @@ namespace quda template struct vector_store_impl { template __device__ __host__ inline void operator()(void *ptr, int idx, const T &value) { - reinterpret_cast(ptr)[idx] = value; + // reinterpret_cast(ptr)[idx] = value; + memcpy(static_cast(ptr) + idx, &value, sizeof(value)); } }; - template __device__ __host__ inline void vector_store(void *ptr, int idx, const vector_t &value) + template __device__ __host__ inline void vector_storeV(void *ptr, int idx, const vector_t &value) { target::dispatch(ptr, idx, value); } @@ -55,7 +64,9 @@ namespace quda vector_t value_v; static_assert(sizeof(value_a) == sizeof(value_v), "array type and vector type are different sizes"); memcpy(&value_v, &value_a, sizeof(vector_t)); - vector_store(ptr, idx, value_v); + //vector_storeV(ptr, idx, value_v); + scalar_t *a = static_cast(ptr) + N*idx; + memcpy(a, &value_v, sizeof(value_v)); } } // namespace quda diff --git a/include/targets/generic/special_ops.h b/include/targets/generic/special_ops.h new file mode 100644 index 0000000000..9d4683a1d8 --- /dev/null +++ b/include/targets/generic/special_ops.h @@ -0,0 +1,184 @@ +#pragma once +#include +#include + +namespace quda +{ + +#if 0 + // dimensions functors for SharedMemoryCache + struct opDimsBlock { + template static constexpr dim3 dims(dim3 b, const Arg &...) { return b; } + }; + template struct opDimsStatic { + template static constexpr dim3 dims(dim3, const Arg &...) { return dim3(bx,by,bz); } + }; + + // size functors for determining shared memory size + struct opSizeBlock { + template static constexpr unsigned int size(dim3 b, const Arg &...) { + return b.x * b.y * b.z * sizeof(T); + } + }; + struct opSizeBlockDivWarp { + template static constexpr unsigned int size(dim3 b, const Arg &...) { + return ((b.x * b.y * b.z + device::warp_size() - 1)/device::warp_size()) * sizeof(T); + } + }; + template struct opSizeStatic { + template static constexpr unsigned int size(dim3, const Arg &...) 
{ + return S * sizeof(T); + } + }; + template struct opSizeDims { + template static constexpr unsigned int size(dim3 b, const Arg &...arg) { + return opSizeBlock::size(D::dims(b, arg...)); + } + }; +#endif + + template static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg); + + // alternative to SpecialOps + struct NoSpecialOps { + using SpecialOpsT = NoSpecialOps; + using KernelOpsT = NoSpecialOps; + }; + // SpecialOps forward declaration and base type + template struct SpecialOps; + template using KernelOps = SpecialOps; + template struct SpecialOps_Base { + using SpecialOpsT = SpecialOps; + using KernelOpsT = SpecialOps; + template static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) + { + return sharedMemSize(block, arg...); + } + }; + + // getSpecialOps + template struct getSpecialOpsS { + using type = NoSpecialOps; + }; + template struct getSpecialOpsS> { + using type = typename T::SpecialOpsT; + }; + template using getSpecialOps = typename getSpecialOpsS::type; + + // hasSpecialOp: checks if first type matches any of the op + // > + template static constexpr bool hasSpecialOp = false; + template + static constexpr bool hasSpecialOp> = (std::is_same_v || ...); + + // template void checkSpecialOps() { static_assert(hasSpecialOp); } + // template void checkSpecialOps(const Ops &) { + // static_assert(hasSpecialOp); + // } + template void checkSpecialOps(const Ops &) + { + static_assert((hasSpecialOp || ...)); + } + + // forward declarations of op types + struct op_blockSync; + template struct op_warp_combine; + + // only types for convenience + using only_blockSync = SpecialOps; + template using only_warp_combine = SpecialOps>; + + // explicitSpecialOps + template struct explicitSpecialOpsS : std::false_type { + }; + template + struct explicitSpecialOpsS> : std::true_type { + }; + template inline constexpr bool explicitSpecialOps = explicitSpecialOpsS::value; + + // hasSpecialOps +#if 1 + template inline constexpr bool hasSpecialOps = !std::is_same_v, NoSpecialOps>; +#else + template struct hasSpecialOpsImpl { + static constexpr bool value = false; + }; + template struct hasSpecialOpsImpl> { + static constexpr bool value = true; + }; + template inline constexpr bool hasSpecialOps = hasSpecialOpsImpl::value; +#endif + + // checkSpecialOp + template static constexpr void checkSpecialOp() + { + static_assert((std::is_same_v || ...) == true); + } + + // combineOps + template struct combineOpsS { + }; + template struct combineOpsS> { + using type = SpecialOps; + }; + template struct combineOpsS, NoSpecialOps> { + using type = SpecialOps; + }; + template struct combineOpsS, SpecialOps> { + using type = SpecialOps; + }; + template using combineOps = typename combineOpsS::type; + + // sharedMemSize +#if 0 + template struct sharedMemSizeS { + template + static constexpr unsigned int size(dim3 block, Arg &...arg) { + return std::max({sharedMemSizeS::size(block, arg...)...}); + } + }; + template static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg) { + return sharedMemSizeS::size(block, arg...); + } + template struct sharedMemSizeS> { + template + static constexpr unsigned int size(dim3 block, Arg &...arg) { return sharedMemSize(block, arg...); } + }; + template struct sharedMemSizeS> { + template + static constexpr unsigned int size(dim3 block, Arg &...arg) { return sharedMemSize(block, arg...); } + }; + template struct sharedMemSizeS> { + template + static constexpr unsigned int size(dim3 block, Arg &...arg) { return (sharedMemSize(block, arg...) 
+ ...); } + }; + template struct sharedMemSizeS { // T should be of op_Base + template + static constexpr unsigned int size(dim3 block, Arg &...arg) { + return sharedMemSize(block, arg...); + } + }; +#else + template struct sharedMemSizeS { + template static constexpr unsigned int size(dim3 block, Arg &...arg) + { + // return 0; + return T::shared_mem_size(block, arg...); + } + }; + template <> struct sharedMemSizeS { + template static constexpr unsigned int size(dim3, Arg &...) { return 0; } + }; + template struct sharedMemSizeS> { + template static constexpr unsigned int size(dim3 block, Arg &...arg) + { + return std::max({sharedMemSizeS::size(block, arg...)...}); + } + }; + template static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg) + { + return sharedMemSizeS::size(block, arg...); + } +#endif + +} // namespace quda diff --git a/include/targets/hip/special_ops_target.h b/include/targets/hip/special_ops_target.h new file mode 100644 index 0000000000..3aea0745a5 --- /dev/null +++ b/include/targets/hip/special_ops_target.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace quda +{ + + // SpecialOps + template struct SpecialOps : SpecialOps_Base { + template constexpr void setSpecialOps(const SpecialOps &) + { + static_assert(std::is_same_v, SpecialOps>); + } + }; + + // op implementations + struct op_blockSync { + template static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } + }; + + template struct op_warp_combine { + template static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } + }; + +} // namespace quda diff --git a/include/targets/sycl/FFT_Plans.h b/include/targets/sycl/FFT_Plans.h new file mode 100644 index 0000000000..5b35701201 --- /dev/null +++ b/include/targets/sycl/FFT_Plans.h @@ -0,0 +1,187 @@ +#pragma once + +#ifndef NATIVE_FFT_LIB +#include "../generic/FFT_Plans.h" +#else + +#include +#include +#include +using namespace oneapi::mkl::dft; + +#define FFT_FORWARD 0 +#define FFT_INVERSE 1 + +namespace quda +{ + + typedef struct { + bool isDouble; + union { + descriptor *s; + descriptor *d; + }; + } FFTPlanHandle; + + inline static constexpr bool HaveFFT() { return true; } + + /** + * @brief Call MKL to perform a single-precision complex-to-complex + * transform plan in the transform direction as specified by direction + * parameter + * @param[in] MKL FFT plan + * @param[in] data_in, pointer to the complex input data (in GPU memory) to transform + * @param[out] data_out, pointer to the complex output data (in GPU memory) + * @param[in] direction, the transform direction: CUFFT_FORWARD or CUFFT_INVERSE + */ + inline void ApplyFFT(FFTPlanHandle &plan, float2 *data_in, float2 *data_out, int direction) + { + if (plan.isDouble) { errorQuda("Called single precision FFT with double precision plan\n"); } + sycl::event e; + if (direction == FFT_FORWARD) { + // warningQuda("Forward FFT"); + e = compute_forward(*plan.s, (float *)data_in, (float *)data_out); + } else { + // warningQuda("Backward FFT"); + e = compute_backward(*plan.s, (float *)data_in, (float *)data_out); + } + e.wait(); + // warningQuda("Done FFT"); + } + + /** + * @brief Call CUFFT to perform a double-precision complex-to-complex transform plan in the transform direction + as specified by direction parameter + * @param[in] CUFFT plan + * @param[in] data_in, pointer to the complex input data (in GPU memory) to transform + * @param[out] data_out, pointer to the complex output data (in GPU memory) + * @param[in] direction, the transform direction: CUFFT_FORWARD 
or CUFFT_INVERSE + */ + inline void ApplyFFT(FFTPlanHandle &plan, double2 *data_in, double2 *data_out, int direction) + { + if (!plan.isDouble) { errorQuda("Called double precision FFT with single precision plan\n"); } + sycl::event e; + if (direction == FFT_FORWARD) { + e = compute_forward(*plan.d, (double *)data_in, (double *)data_out); + } else { + e = compute_backward(*plan.d, (double *)data_in, (double *)data_out); + } + e.wait(); + } + + /** + * @brief Creates a CUFFT plan supporting 4D (1D+3D) data layouts for complex-to-complex + * @param[out] plan, CUFFT plan + * @param[in] size, int4 with lattice size dimensions, (.x,.y,.z,.w) -> (Nx, Ny, Nz, Nt) + * @param[in] dim, 1 for 1D plan along the temporal direction with batch size Nx*Ny*Nz, 3 for 3D plan along Nx, Ny and + * Nz with batch size Nt + * @param[in] precision The precision of the computation + */ + + // inline void SetPlanFFTMany(FFTPlanHandle &plan, int4 size, int dim, QudaPrecision precision) + inline void SetPlanFFTMany(FFTPlanHandle &, int4, int dim, QudaPrecision precision) + { + warningQuda("SetPlanFFTMany %i %i : unimplemented", dim, precision); +#if 0 + auto type = precision == QUDA_DOUBLE_PRECISION ? CUFFT_Z2Z : CUFFT_C2C; + switch (dim) { + case 1: { + int n[1] = {size.w}; + CUFFT_SAFE_CALL(cufftPlanMany(&plan, 1, n, NULL, 1, 0, NULL, 1, 0, type, size.x * size.y * size.z)); + } break; + case 3: { + int n[3] = {size.x, size.y, size.z}; + CUFFT_SAFE_CALL(cufftPlanMany(&plan, 3, n, NULL, 1, 0, NULL, 1, 0, type, size.w)); + } break; + } + CUFFT_SAFE_CALL(cufftSetStream(plan, target::cuda::get_stream(device::get_default_stream()))); +#endif + } + + /** + * @brief Creates a CUFFT plan supporting 4D (2D+2D) data layouts for complex-to-complex + * @param[out] plan, CUFFT plan + * @param[in] size, int4 with lattice size dimensions, (.x,.y,.z,.w) -> (Nx, Ny, Nz, Nt) + * @param[in] dim, 0 for 2D plan in Z-T planes with batch size Nx*Ny, 1 for 2D plan in X-Y planes with batch size Nz*Nt + * @param[in] precision The precision of the computation + */ + inline void SetPlanFFT2DMany(FFTPlanHandle &plan, int4 size, int dim, QudaPrecision precision) + { + // warningQuda("SetPlanFFT2DMany %i %i", dim, precision); + if (precision == QUDA_SINGLE_PRECISION) { + plan.isDouble = false; + if (dim == 0) { + auto q = quda::device::defaultQueue(); + MKL_LONG distance = size.w * size.z; + plan.s = new std::remove_pointer_t({size.w, size.z}); + // plan.s = new std::remove_pointer_t({size.z, size.w}); + plan.s->set_value(config_param::NUMBER_OF_TRANSFORMS, size.x * size.y); + plan.s->set_value(config_param::FWD_DISTANCE, distance); + plan.s->set_value(config_param::BWD_DISTANCE, distance); + plan.s->set_value(config_param::BACKWARD_SCALE, (1.0 / distance)); + plan.s->commit(q); + } else { + auto q = quda::device::defaultQueue(); + MKL_LONG distance = size.x * size.y; + // plan.s = new std::remove_pointer_t({size.x, size.y}); + plan.s = new std::remove_pointer_t({size.y, size.x}); + plan.s->set_value(config_param::NUMBER_OF_TRANSFORMS, size.w * size.z); + plan.s->set_value(config_param::FWD_DISTANCE, distance); + plan.s->set_value(config_param::BWD_DISTANCE, distance); + plan.s->set_value(config_param::BACKWARD_SCALE, (1.0 / distance)); + plan.s->commit(q); + } + } else { + plan.isDouble = true; + if (dim == 0) { + auto q = quda::device::defaultQueue(); + MKL_LONG distance = size.w * size.z; + plan.d = new std::remove_pointer_t({size.w, size.z}); + // plan.d = new std::remove_pointer_t({size.z, size.w}); + 
plan.d->set_value(config_param::NUMBER_OF_TRANSFORMS, size.x * size.y); + plan.d->set_value(config_param::FWD_DISTANCE, distance); + plan.d->set_value(config_param::BWD_DISTANCE, distance); + plan.d->set_value(config_param::BACKWARD_SCALE, (1.0 / distance)); + plan.d->commit(q); + } else { + auto q = quda::device::defaultQueue(); + MKL_LONG distance = size.x * size.y; + // plan.d = new std::remove_pointer_t({size.x, size.y}); + plan.d = new std::remove_pointer_t({size.y, size.x}); + plan.d->set_value(config_param::NUMBER_OF_TRANSFORMS, size.w * size.z); + plan.d->set_value(config_param::FWD_DISTANCE, distance); + plan.d->set_value(config_param::BWD_DISTANCE, distance); + plan.d->set_value(config_param::BACKWARD_SCALE, (1.0 / distance)); + plan.d->commit(q); + } + } +#if 0 + auto type = precision == QUDA_DOUBLE_PRECISION ? CUFFT_Z2Z : CUFFT_C2C; + switch (dim) { + case 0: { + int n[2] = {size.w, size.z}; + CUFFT_SAFE_CALL(cufftPlanMany(&plan, 2, n, NULL, 1, 0, NULL, 1, 0, type, size.x * size.y)); + } break; + case 1: { + int n[2] = {size.x, size.y}; + CUFFT_SAFE_CALL(cufftPlanMany(&plan, 2, n, NULL, 1, 0, NULL, 1, 0, type, size.z * size.w)); + } break; + } + CUFFT_SAFE_CALL(cufftSetStream(plan, target::cuda::get_stream(device::get_default_stream()))); +#endif + } + + inline void FFTDestroyPlan(FFTPlanHandle &plan) + { + if (plan.isDouble) { + // plan.d->~descriptor(); + delete plan.d; + } else { + // plan.s->~descriptor(); + delete plan.s; + } + } + +} // namespace quda + +#endif // ifndef NATIVE_FFT_LIB diff --git a/include/targets/sycl/aos.h b/include/targets/sycl/aos.h new file mode 100644 index 0000000000..cafb8b529f --- /dev/null +++ b/include/targets/sycl/aos.h @@ -0,0 +1,61 @@ +#pragma once + +// #include + +namespace quda +{ + +#if 0 + template struct subgroup_load_store { + //using atom_t = std::conditional_t>; + //using atom_t = std::conditional_t>; + using atom_t = int; + static_assert(sizeof(T) % 4 == 0, "block_load & block_store do not support sub-word size types"); + static constexpr int n_element = sizeof(T) / sizeof(atom_t); + using vec = atom_t[n_element]; + }; +#endif + + template __host__ __device__ void block_load(T out[n], const T *in) + { + // #pragma unroll + // for (int i = 0; i < n; i++) out[i] = in[i]; + memcpy(out, in, n * sizeof(T)); + // using U = T[n]; + // using LS = subgroup_load_store; + // using V = typename LS::vec; + // using A = typename LS::atom_t; + // constexpr int nv = LS::n_element; + // auto sg = sycl::ext::oneapi::experimental::this_sub_group(); + // auto vin = reinterpret_cast(in) - sg.get_local_id(); + // auto vin = reinterpret_cast(in) - nv*sg.get_local_id(); + // auto t = sg.load(sycl::multi_ptr{vin}); + // #pragma unroll + // for (int i = 0; i < nv; i++) t[i] = sg.load(vin + sg); + // auto vout = sg.load(vin); + // #pragma unroll + // for (int i = 0; i < n; i++) out[i] = vout[i]; + } + + template __host__ __device__ void block_store(T *out, const T in[n]) + { + // #pragma unroll + // for (int i = 0; i < n; i++) out[i] = in[i]; + memcpy(out, in, n * sizeof(T)); + } + + template __host__ __device__ void block_load(T &out, const T *in) + { + // out = *in; + memcpy(&out, in, sizeof(T)); + // auto sg = sycl::ext::oneapi::experimental::this_sub_group(); + // out = sg.load(in - sg.get_local_id()); + } + + template __host__ __device__ void block_store(T *out, const T &in) + { + //*out = in; + memcpy(out, &in, sizeof(T)); + } + +} // namespace quda diff --git a/include/targets/sycl/atomic.cuh b/include/targets/sycl/atomic.cuh new file mode 100644 
index 0000000000..ff3e58daac --- /dev/null +++ b/include/targets/sycl/atomic.cuh @@ -0,0 +1,241 @@ +// old depracated version + +#pragma once + +/** + @file atomic.cuh + + @section Description + + Provides definitions of atomic functions that are not native to + CUDA. These are intentionally not declared in the namespace to + avoid confusion when resolving the native atomicAdd functions. + */ + +// inline constexpr auto mo = sycl::memory_order::relaxed; +inline constexpr auto mo = sycl::ext::oneapi::memory_order::acq_rel; +// inline constexpr auto mo = memory_order::seq_cst; +// inline constexpr auto mo = sycl::memory_order::acq_rel; + +// inline constexpr auto ms = sycl::ext::oneapi::memory_scope::system; +inline constexpr auto ms = sycl::ext::oneapi::memory_scope::device; +// inline constexpr auto ms = sycl::memory_scope::system; + +// inline constexpr auto as = sycl::access::address_space::generic_space; +inline constexpr auto as = sycl::access::address_space::global_space; + +template using atomicRef = sycl::ext::oneapi::atomic_ref; +// using atomicRef = sycl::atomic_ref; + +template static inline atomicRef makeAtomicRef(T *address) { return atomicRef(*address); } + +static inline uint __float_as_uint(float x) { return *reinterpret_cast(&x); } + +static inline float __uint_as_float(uint x) { return *reinterpret_cast(&x); } + +static inline unsigned int atomicMax(unsigned int *address, unsigned int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_max(val); + return old; +} + +static inline int atomicCAS(int *address, int compare, int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.compare_exchange_strong(compare, val); + return old; +} +static inline unsigned int atomicCAS(unsigned int *address, unsigned int compare, unsigned int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.compare_exchange_strong(compare, val); + return old; +} + +/** + @brief Implementation of double-precision atomic addition using compare + and swap. Taken from the CUDA programming guide. + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline int atomicAdd(int *address, int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} +static inline float atomicAdd(float *address, float val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} +static inline double atomicAdd(double *address, double val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} + +/** + @brief Implementation of double2 atomic addition using two + double-precision additions. + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline double2 atomicAdd(double2 *addr, double2 val) +{ + double2 old = *addr; + // This is a necessary evil to avoid conflicts between the atomicAdd + // declared in the CUDA headers which are visible for host + // compilation, which cause a conflict when compiled on clang-cuda. + // As a result we do not support any architecture without native + // double precision atomics on clang-cuda. + old.x = atomicAdd((double *)addr, val.x); + old.y = atomicAdd((double *)addr + 1, val.y); + return old; +} + +/** + @brief Implementation of float2 atomic addition using two + single-precision additions. 
+ + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline float2 atomicAdd(float2 *addr, float2 val) +{ + float2 old = *addr; + old.x = atomicAdd((float *)addr, val.x); + old.y = atomicAdd((float *)addr + 1, val.y); + return old; +} + +/** + @brief Implementation of int2 atomic addition using two + int additions. + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline int2 atomicAdd(int2 *addr, int2 val) +{ + int2 old = *addr; + old.x = atomicAdd((int *)addr, val.x); + old.y = atomicAdd((int *)addr + 1, val.y); + return old; +} + +union uint32_short2 { + unsigned int i; + short2 s; +}; + +/** + @brief Implementation of short2 atomic addition using compare + and swap. + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline short2 atomicAdd(short2 *addr, short2 val) +{ + uint32_short2 old, assumed, incremented; + old.s = *addr; + do { + assumed.s = old.s; + incremented.s = make_short2(val.x + assumed.s.x, val.y + assumed.s.y); + old.i = atomicCAS((unsigned int *)addr, assumed.i, incremented.i); + } while (assumed.i != old.i); + + return old.s; +} + +union uint32_char2 { + unsigned short i; + char2 s; +}; + +/** + @brief Implementation of char2 atomic addition using compare + and swap. + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline char2 atomicAdd(char2 *addr, char2 val) +{ + uint32_char2 old, assumed, incremented; + old.s = *addr; + do { + assumed.s = old.s; + incremented.s = make_char2(val.x + assumed.s.x, val.y + assumed.s.y); + old.i = atomicCAS((unsigned int *)addr, assumed.i, incremented.i); + } while (assumed.i != old.i); + + return old.s; +} + +/** + @brief Implementation of single-precision atomic max using compare + and swap. May not support NaNs properly... + + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline float atomicMax(float *addr, float val) +{ + unsigned int old = __float_as_uint(*addr), assumed; + do { + assumed = old; + if (__uint_as_float(old) >= val) break; + + old = atomicCAS((unsigned int *)addr, assumed, __float_as_uint(val)); + } while (assumed != old); + + return __uint_as_float(old); +} + +/** + @brief Implementation of single-precision atomic max specialized + for positive-definite numbers. Here we take advantage of the + property that when positive floating point numbers are + reinterpretted as unsigned integers, they have the same unique + sorted order. 
+ + @param addr Address that stores the atomic variable to be updated + @param val Value to be added to the atomic +*/ +static inline float atomicAbsMax(float *addr, float val) +{ + uint32_t val_ = __float_as_uint(val); + uint32_t *addr_ = reinterpret_cast(addr); + return atomicMax(addr_, val_); +} + +/** + @brief atomic_fetch_add function performs similarly as atomic_ref::fetch_add + @param[in,out] addr The memory address of the variable we are + updating atomically + @param[in] val The value we summing to the value at addr + */ +template inline void atomic_fetch_add(T *addr, T val) { atomicAdd(addr, val); } + +#if 0 +template void atomic_fetch_add(vector_type *addr, vector_type val) +{ + for (int i = 0; i < n; i++) atomic_fetch_add(&(*addr)[i], val[i]); +} +#endif + +/** + @brief atomic_fetch_max function that does an atomic max. + @param[in,out] addr The memory address of the variable we are + updating atomically + @param[in] val The value we are comparing against. Must be + positive valued else result is undefined. + */ +template inline void atomic_fetch_abs_max(T *addr, T val) { atomicAbsMax(addr, val); } diff --git a/include/targets/sycl/atomic_helper.h b/include/targets/sycl/atomic_helper.h new file mode 100644 index 0000000000..2c7eb57536 --- /dev/null +++ b/include/targets/sycl/atomic_helper.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +/** + @file atomic_helper.h + + @section Provides definitions of atomic functions that are used in QUDA. + */ + +inline constexpr auto mo = sycl::memory_order::relaxed; +// inline constexpr auto mo = sycl::ext::oneapi::memory_order::acq_rel; +// inline constexpr auto mo = memory_order::seq_cst; +// inline constexpr auto mo = sycl::memory_order::acq_rel; + +// inline constexpr auto ms = sycl::memory_scope::system; +inline constexpr auto ms = sycl::memory_scope::device; +inline constexpr auto msg = sycl::memory_scope::work_group; + +// inline constexpr auto as = sycl::access::address_space::generic_space; +inline constexpr auto as = sycl::access::address_space::global_space; +inline constexpr auto asl = sycl::access::address_space::local_space; + +// using atomicRef = sycl::ext::oneapi::atomic_ref; +template using atomicRef = sycl::atomic_ref; +template using atomicRefL = sycl::atomic_ref; + +template static inline atomicRef makeAtomicRef(T *address) { return atomicRef(*address); } + +template static inline atomicRefL makeAtomicRefL(T *address) { return atomicRefL(*address); } + +#if 0 +using lfloat = std::remove_pointer_t>().get())>; +using ldouble = std::remove_pointer_t>().get())>; + +static inline atomicRefL makeAtomicRef(lfloat *address) { + return atomicRefL(*address); +} + +static inline atomicRefL makeAtomicRef(ldouble *address) { + return atomicRefL(*address); +} + +static inline atomicRefL makeAtomicRefL(lfloat *address) { + return atomicRefL(*address); +} + +static inline atomicRefL makeAtomicRefL(ldouble *address) { + return atomicRefL(*address); +} +#endif + +static inline uint __float_as_uint(float x) { return *reinterpret_cast(&x); } + +#if 0 +static inline int atomicAdd(int *address, int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} + +static inline unsigned int atomicAdd(unsigned int *address, unsigned int val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} + +static inline float atomicAdd(float *address, float val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} + +static inline double atomicAdd(double 
*address, double val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} +#endif +template static inline int atomicAdd(T *address, U val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_add(val); + return old; +} + +static inline uint32_t atomicMax(uint32_t *address, uint32_t val) +{ + auto ar = makeAtomicRef(address); + auto old = ar.fetch_max(val); + return old; +} + +template __device__ __host__ inline void atomic_fetch_add(T *addr, U val) +{ + atomicAdd(addr, val); +} + +template __device__ __host__ inline void atomic_add_shared(T *addr, U val) +{ + auto ar = makeAtomicRefL(addr); + ar += val; +} + +#include <../cuda/atomic_helper.h> diff --git a/include/targets/sycl/block_reduce_helper.h b/include/targets/sycl/block_reduce_helper.h new file mode 100644 index 0000000000..c18fe855be --- /dev/null +++ b/include/targets/sycl/block_reduce_helper.h @@ -0,0 +1,223 @@ +#pragma once + +#include +#include +#include +#include +#include + +/** + @file block_reduce_helper.h + + @section This files contains the CUDA device specializations for + warp- and block-level reductions, using the CUB library + */ + +// using namespace quda; + +namespace quda +{ + + /** + @brief The atomic word size we use for a given reduction type. + This type should be lock-free to guarantee correct behaviour on + platforms that are not coherent with respect to the host + */ + template struct atomic_type; + + template <> struct atomic_type { + using type = device_reduce_t; + }; + + template <> struct atomic_type { + using type = float; + }; + + template struct atomic_type>>> { + using type = device_reduce_t; + }; + + template + struct atomic_type, T::N>>>> { + using type = device_reduce_t; + }; + + template struct atomic_type, T::N>>>> { + using type = double; + }; + + template struct atomic_type, T::N>>>> { + using type = float; + }; + + // pre-declaration of warp_reduce that we wish to specialize + template struct warp_reduce; + + /** + @brief SYCL specialization of warp_reduce, utilizing subgroup operations + */ + template <> struct warp_reduce { + + /** + @brief Perform a warp-wide reduction using subgroups + @param[in] value_ thread-local value to be reduced + @param[in] all Whether we want all threads to have visibility + to the result (all = true) or just the first thread in the + warp (all = false). 
+ @param[in] r The reduction operation we want to apply + @return The warp-wide reduced value + */ + template + T inline operator()(const T &value_, bool all, const reducer_t &r, const param_t &) + { + // auto sg = sycl::ext::oneapi::experimental::this_sub_group(); + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + T value = value_; +#pragma unroll + for (int offset = param_t::width / 2; offset >= 1; offset /= 2) { + value = r(value, sycl::shift_group_left(sg, value, offset)); + } + // if (all) value = sycl::select_from_group(sg, value, 0); + if (all) value = sycl::group_broadcast(sg, value); + return value; + } + }; + + // pre-declaration of block_reduce that we wish to specialize + // template struct block_reduce; + template class BlockReduce; + + /** + @brief SYCL specialization of block_reduce, using SYCL group reductions + */ + template struct block_reduceG { + // using dependencies = op_Sequential; + // using dependentOps = KernelOps; + using BlockReduce_t = BlockReduce; + template inline block_reduceG(S &) {}; + /** + @brief Perform a block-wide reduction + @param[in] value_ thread-local value to be reduced + @param[in] async Whether this reduction will be performed + asynchronously with respect to the calling threads + @param[in] batch The batch index of the reduction + @param[in] all Whether we want all threads to have visibility + to the result (all = true) or just the first thread in the + block (all = false) + @param[in] r The reduction operation we want to apply + @return The block-wide reduced value + */ + template inline T apply(const T &value_, bool async, int batch, bool, const reducer_t &r) + { + if (!async) __syncthreads(); // only synchronize if we are not pipelining + const int nbatch = batch_size; + // const int nbatch = std::min(param_t::batch_size, localRangeZ); + auto grp = getGroup(); + T result; + // T result = reducer_t::init(); + for (int i = 0; i < nbatch; i++) { + T in = (i == batch) ? 
value_ : reducer_t::init(); + T out; + blockReduce(grp, out, in, r); + if (i == batch) result = out; + } + return result; + } + }; + + /** + @brief SYCL specialization of block_reduce, building on the warp_reduce + */ + template struct block_reduceW : SharedMemory { + using Smem = SharedMemory; + using BlockReduce_t = BlockReduce; + template inline block_reduceW(S &ops) : Smem(ops) {}; + + template struct warp_reduce_param { + static constexpr int width = width_; + }; + + /** + @brief Perform a block-wide reduction + @param[in] value_ thread-local value to be reduced + @param[in] async Whether this reduction will be performed + asynchronously with respect to the calling threads + @param[in] batch The batch index of the reduction + @param[in] all Whether we want all threads to have visibility + to the result (all = true) or just the first thread in the + block (all = false) + @param[in] r The reduction operation we want to apply + @return The block-wide reduced value + */ + template inline T apply(const T &value_, bool async, int batch, bool all, const reducer_t &r) + { + constexpr auto max_items = device::max_block_size() / device::warp_size(); + const auto thread_idx = target::thread_idx_linear(); + const auto block_size = target::block_size(); + const auto warp_idx = thread_idx / device::warp_size(); + const auto warp_items = (block_size + device::warp_size() - 1) / device::warp_size(); + + // first do warp reduce + T value = warp_reduce()(value_, false, r, warp_reduce_param()); + + if (!all && warp_items == 1) return value; // short circuit for single warp CTA + + // now do reduction between warps + if (!async) __syncthreads(); // only synchronize if we are not pipelining + + auto storage = Smem::sharedMem(); + + // if first thread in warp, write result to shared memory + if (thread_idx % device::warp_size() == 0) storage[batch * warp_items + warp_idx] = value; + // blockSync(ops); + __syncthreads(); + + // whether to use the first warp or first thread for the final reduction + constexpr bool final_warp_reduction = true; + + if constexpr (final_warp_reduction) { // first warp completes the reduction (requires first warp is full) + if (warp_idx == 0) { + if constexpr (max_items > device::warp_size()) { // never true for max block size 1024, warp = 32 + value = r.init(); + for (auto i = thread_idx; i < warp_items; i += device::warp_size()) + value = r(storage[batch * warp_items + i], value); + } else { // optimized path where we know the final reduction will fit in a warp + value = thread_idx < warp_items ? 
storage[batch * warp_items + thread_idx] : r.init(); + } + value = warp_reduce()(value, false, r, warp_reduce_param()); + } + } else { // first thread completes the reduction + if (thread_idx == 0) { + for (unsigned int i = 1; i < warp_items; i++) value = r(storage[batch * warp_items + i], value); + } + } + + if (all) { + if (thread_idx == 0) storage[batch * warp_items + 0] = value; + // blockSync(ops); + __syncthreads(); + value = storage[batch * warp_items + 0]; + } + + return value; + } + }; + + // template using block_reduce = block_reduceG; + template using block_reduce = block_reduceW; + +} // namespace quda + +#include "../generic/block_reduce_helper.h" + +namespace quda +{ + template + static constexpr bool needsFullBlockImpl> = true; + template + static constexpr bool needsSharedMemImpl> = true; +} // namespace quda + +static_assert(needsFullBlock>> == true); +static_assert(BlockReduce::shared_mem_size(dim3 {8, 8, 8}) > 0); +static_assert(needsSharedMem>> == true); diff --git a/include/targets/sycl/block_reduce_helper_rog.h b/include/targets/sycl/block_reduce_helper_rog.h new file mode 100644 index 0000000000..25c4f13a95 --- /dev/null +++ b/include/targets/sycl/block_reduce_helper_rog.h @@ -0,0 +1,378 @@ +#pragma once + +#include +#include +#include + +/** + @file block_reduce_helper.h + + @section This files contains the SYCL implementations + for warp- and block-level reductions. + */ + +using namespace quda; + +namespace quda +{ + +#if 0 + /** + @brief warp_reduce_param is used as a container for passing + non-type parameters to specialize warp_reduce through the + target::dispatch + @tparam width The number of logical threads taking part in the warp reduction + */ + template struct warp_reduce_param { + static_assert(width_ <= device::warp_size(), "WarpReduce logical width must not be greater than the warp size"); + static constexpr int width = width_; + }; + + /** + @brief block_reduce_param is used as a container for passing + non-type parameters to specialize block_reduce through the + target::dispatch + @tparam block_size_x_ The number of threads in the x dimension + @tparam block_size_y_ The number of threads in the y dimension + @tparam block_size_z_ The number of threads in the z dimension + @tparam batched Whether this is a batched reduction or not. If + batched, then the block_size_z_ parameter is set equal to the + batch size. + */ + template struct block_reduce_param { + static constexpr int block_size_x = block_size_x_; + static constexpr int block_size_y = block_size_y_; + static constexpr int block_size_z = !batched ? block_size_z_ : 1; + static constexpr int batch_size = !batched ? 
1 : block_size_z_; + }; + + /** + @brief Dummy generic implementation of warp_reduce + */ + template struct warp_reduce { + template T operator()(const T &value, bool, reducer_t, param_t) + { + return value; + } + }; + + /** + @brief Dummy generic implementation of block_reduce + */ + template struct block_reduce { + template + T operator()(const T &value, bool, int, bool, reducer_t, param_t) + { + return value; + } + }; +#endif + + /** + @brief WarpReduce provides a generic interface for performing + perform reductions at the warp or sub-warp level + @tparam T The type of the value that we are reducing + @tparam width The number of logical threads taking part in the warp reduction + */ + template class WarpReduce + { + static_assert(width <= device::warp_size(), "WarpReduce logical width must not be greater than the warp size"); + // using param_t = warp_reduce_param; + // const nreduce = device::warp_size() / width; + + public: + constexpr WarpReduce() { } + + /** + @brief Perform a warp-wide sum reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + inline T Sum(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::Sum unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + return value; + // return target::dispatch(value, false, quda::plus(), param_t()); + } + + /** + @brief Perform a warp-wide sum reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads within the logical warp) + */ + inline T AllSum(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::AllSum unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + return value; + // return target::dispatch(value, true, quda::plus(), param_t()); + } + + /** + @brief Perform a warp-wide max reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + inline T Max(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::Max unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + return value; + // return target::dispatch(value, false, quda::maximum(), param_t()); + } + + /** + @brief Perform a warp-wide max reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads within the logical warp) + */ + inline T AllMax(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::AllMax unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + return value; + // return target::dispatch(value, true, quda::maximum(), param_t()); + } + + /** + @brief Perform a warp-wide min reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + inline T Min(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::Min unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + return value; + // return target::dispatch(value, false, quda::minimum(), param_t()); + } + + /** + @brief Perform a warp-wide min reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads within the logical warp) + */ + inline T AllMin(const T &value) + { + // static const __SYCL_CONSTANT_AS char format[] = "WarpReduce::AllMin unimplemented\n"; + // sycl::ext::oneapi::experimental::printf(format); + 
return value; + // return target::dispatch(value, true, quda::minimum(), param_t()); + } + }; + + /** + @brief BlockReduce provides a generic interface for performing + reductions at the block level + @tparam T The type of the value that we are reducing + @tparam block_dim The number of thread block dimensions we are reducing + @tparam batch_size Batch size of the reduction. Threads will be + ordered such that batch size is the slowest running index. + */ + template class BlockReduce + { + static constexpr int batch_size = std::max(batch_size_, 1); + const int nbatch = batch_size_ != 0 ? batch_size_ : localRangeZ; + const int batch; + + public: + constexpr BlockReduce(int batch = 0) : batch(batch) { } + + /** + @brief Perform a block-wide sum reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + template inline T Sum(const T &value) + { + if (!async) __syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); +#if 1 + T result; + // for(int i=0; i(); + T out; + blockReduceSum(grp, out, in); + if (i == batch) result = out; + } + return result; +#else + using atype = T[512]; // FIXME + auto mem0 = sycl::ext::oneapi::group_local_memory_for_overwrite(grp); + auto mem = *mem0.get(); + auto r0 = localRangeX; + auto r1 = localRangeY; + auto r2 = localRangeZ; + auto i0 = localIdX; + auto i1 = localIdY; + auto i2 = localIdZ; + auto r = r0 * r1; + auto i = i1 * r0 + i0; + if (i2 * r + i < 512) { mem[i2 * r + i] = value; } + group_barrier(grp); + for (int s = 1; s < r; s *= 2) { + int a = 2 * s * i; + int as = a + s; + if (as < r) { + if (i2 * r + as < 512) { mem[i2 * r + a] = mem[i2 * r + a] + mem[i2 * r + as]; } + } + group_barrier(grp); + } + return mem[0]; +#endif + } + + /** + @brief Perform a block-wide sum reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads in the block) + */ + template __device__ __host__ inline T AllSum(const T &value) + { + static_assert(batch_size == 1, "Cannot do AllSum with batch_size > 1"); + if (!async) __syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); + T result; + blockReduceSum(grp, result, value); + return result; + } + + /** + @brief Perform a block-wide max reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + template __device__ __host__ inline T Max(const T &value) + { + static_assert(batch_size == 1, "Cannot do Max with batch_size > 1"); + if (!async) __syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); + T result; + blockReduceMax(grp, result, value); + return result; + } + + /** + @brief Perform a block-wide max reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads in the block) + */ + template __device__ __host__ inline T AllMax(const T &value) + { + static_assert(batch_size == 1, "Cannot do AllMax with batch_size > 1"); + if (!async) __syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); + T result; + blockReduceMax(grp, result, value); + return result; + } + + /** + @brief Perform a block-wide min reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in logical thread 0 only) + */ + template __device__ __host__ inline T Min(const T &value) + { + static_assert(batch_size == 1, "Cannot do Min with batch_size > 1"); + if (!async) 
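// Editorial sketch: Sum() above reduces each batch slice with QUDA's
// getGroup()/blockReduceSum() helpers, which are defined elsewhere in the SYCL target.
// For a single, unbatched reduction of a plain arithmetic type, the same result can be
// obtained directly from the SYCL 2020 group algorithms, as sketched below; the free
// function and its group parameter are illustrative assumptions rather than the helper
// actually used here.  (assumes <sycl/sycl.hpp> has been included)

template <typename T> inline T workgroup_sum_sketch(const sycl::group<3> &grp, const T &value)
{
  // all work-items in the work-group obtain the sum of their contributions
  return sycl::reduce_over_group(grp, value, sycl::plus<T>());
}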
__syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); + T result; + blockReduceMin(grp, result, value); + return result; + } + + /** + @brief Perform a block-wide min reduction + @param[in] value Thread-local value to be reduced + @return Reduced value (defined in all threads in the block) + */ + template __device__ __host__ inline T AllMin(const T &value) + { + static_assert(batch_size == 1, "Cannot do AllMin with batch_size > 1"); + if (!async) __syncthreads(); // only synchronize if we are not pipelining + auto grp = getGroup(); + T result; + blockReduceMin(grp, result, value); + return result; + } + + /** + @brief Perform a block-wide custom reduction + @param[in] value Thread-local value to be reduced + @param[in] r The reduction operation we want to apply + @return Reduced value (defined in logical thread 0 only) + */ +#if 0 + template + inline T + ReduceNotSum(const T &value, const quda::maximum &r) + { + return Max(value); + } + + template + inline T + ReduceNotSum(const T &value, const quda::minimum &r) + { + return Min(value); + } + + template + inline std::enable_if_t + Reduce(const T &value, const reducer_t &r) + { + return ReduceNotSum(value, typename reducer_t::reducer_t()); + } + + template + inline std::enable_if_t + Reduce(const T &value, const reducer_t &r) + { + return Sum(value); + } +#endif + + template + inline std::enable_if_t>, T> Reduce(const T &value, + const R &) + { + return Sum(value); + } + + template + inline std::enable_if_t>, T> + Reduce(const T &value, const R &) + { + return Max(value); + } + + template + inline std::enable_if_t>, T> + Reduce(const T &value, const R &) + { + return Min(value); + } + +#if 0 + /** + @brief Perform a block-wide custom reduction + @param[in] value Thread-local value to be reduced + @param[in] r The reduction operation we want to apply + @return Reduced value (defined in all threads in the block) + */ + template + inline T AllReduce(const T &value, const R &r) + { + static_assert(batch_size == 1, "Cannot do AllReduce with batch_size > 1"); + auto grp = getGroup(); + T result; + blockReduce(grp, result, value, r); // FIXME: not used + return result; + } +#endif + }; + +} // namespace quda diff --git a/include/targets/sycl/block_reduction_kernel.h b/include/targets/sycl/block_reduction_kernel.h new file mode 100644 index 0000000000..ffe46233f6 --- /dev/null +++ b/include/targets/sycl/block_reduction_kernel.h @@ -0,0 +1,197 @@ +#pragma once +#include +#include +#include + +namespace quda +{ + + /** + @brief This helper function swizzles the block index through + mapping the block index onto a matrix and tranposing it. This is + done to potentially increase the cache utilization. Requires + that the argument class has a member parameter "swizzle" which + determines if we are swizzling and a parameter "swizzle_factor" + which is the effective matrix dimension that we are tranposing in + this mapping. + + Specifically, the thread block id is remapped by + transposing its coordinates: if the original order can be + parameterized by + + blockIdx.x = j * swizzle + i, + + then the new order is + + block_idx = i * (gridDim.x / swizzle) + j + + We need to factor out any remainder and leave this in original + ordering. 
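     (Worked example of the mapping above: with gridDim.x = 10 and swizzle = 4, the
     divisible portion is 8 blocks, so blockIdx.x = 0..7 map to block_idx = 0, 2, 4, 6,
     1, 3, 5, 7 respectively, since i = blockIdx.x % 4 and j = blockIdx.x / 4 give
     block_idx = i * (8 / 4) + j, while blocks 8 and 9 keep their original indices.)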
+
+     @param arg Kernel argument struct
+     @return Swizzled block index
+   */
+  template <typename Arg> int virtual_block_idx(const Arg &arg, const sycl::nd_item<3> &)
+  {
+    int block_idx = groupIdX;
+    if (arg.swizzle) {
+      // the portion of the grid that is exactly divisible by the number of SMs
+      // const int gridp = gridDim.x - gridDim.x % arg.swizzle_factor;
+      const int ngrp = groupRangeX;
+      const int gridp = ngrp - ngrp % arg.swizzle_factor;
+
+      // block_idx = blockIdx.x;
+      // if (blockIdx.x < gridp) {
+      if (block_idx < gridp) {
+        // this is the portion of the block that we are going to transpose
+        // const int i = blockIdx.x % arg.swizzle_factor;
+        // const int j = blockIdx.x / arg.swizzle_factor;
+        const int i = block_idx % arg.swizzle_factor;
+        const int j = block_idx / arg.swizzle_factor;
+
+        // transpose the coordinates
+        block_idx = i * (gridp / arg.swizzle_factor) + j;
+      }
+    }
+    return block_idx;
+  }
+
+  /**
+     @brief This class is derived from the arg class that the functor
+     creates and curries in the block size. This allows the block
+     size to be set statically at launch time in the actual argument
+     class that is passed to the kernel.
+
+     @tparam block_size x-dimension block-size
+     @param[in] arg Kernel argument
+   */
+  template <unsigned int block_size_, typename Arg_> struct BlockKernelArg : Arg_ {
+    using Arg = Arg_;
+    static constexpr unsigned int block_size = block_size_;
+    BlockKernelArg(const Arg &arg) : Arg(arg) { }
+  };
+
+  /**
+     @brief BlockKernel2D_impl is the implementation of the Generic
+     block kernel. Here, we split the block (CTA) and thread indices
+     and pass them separately to the transform functor. The x thread
+     dimension is templated (Arg::block_size), e.g., for efficient
+     reductions.
+
+     @tparam Functor Kernel functor that defines the kernel
+     @tparam Arg Kernel argument struct that sets any required metadata
+     for the kernel
+     @param[in] arg Kernel argument
+   */
+  template
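// Editorial sketch: the remapping performed by virtual_block_idx() above can be checked
// on the host with ordinary integer arithmetic.  The struct and function below merely
// mirror the device code for illustration; their names are hypothetical and not part of QUDA.

#include <cstdio>

struct swizzle_arg_sketch {
  bool swizzle;       // enable the block-index transpose
  int swizzle_factor; // effective matrix dimension being transposed
};

inline int virtual_block_idx_host(const swizzle_arg_sketch &arg, int block_idx, int ngrp)
{
  if (arg.swizzle) {
    // portion of the grid exactly divisible by the swizzle factor
    const int gridp = ngrp - ngrp % arg.swizzle_factor;
    if (block_idx < gridp) {
      const int i = block_idx % arg.swizzle_factor;
      const int j = block_idx / arg.swizzle_factor;
      block_idx = i * (gridp / arg.swizzle_factor) + j; // transpose the (i, j) coordinates
    }
  }
  return block_idx; // remainder blocks (block_idx >= gridp) keep their original order
}

int main()
{
  const swizzle_arg_sketch arg {true, 4};
  // with 10 blocks and swizzle_factor 4: 0->0 1->2 2->4 3->6 4->1 5->3 6->5 7->7 8->8 9->9
  for (int b = 0; b < 10; b++) std::printf("%d -> %d\n", b, virtual_block_idx_host(arg, b, 10));
  return 0;
}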