diff --git a/include/blas_helper.cuh b/include/blas_helper.cuh index 806eef5f5e..aab6d6ec9a 100644 --- a/include/blas_helper.cuh +++ b/include/blas_helper.cuh @@ -79,6 +79,18 @@ namespace quda template <> struct VectorType { using type = array; }; + template <> struct VectorType { + using type = array; + }; + template <> struct VectorType { + using type = array; + }; + template <> struct VectorType { + using type = array; + }; + template <> struct VectorType { + using type = array; + }; template <> struct VectorType { using type = array; }; @@ -331,7 +343,7 @@ namespace quda template <> constexpr int n_vector(int nSpin, int site_unroll) { if (site_unroll) - return nSpin == 4 ? colorspinor::get_vector_order(24) : colorspinor::get_vector_order(6); + return colorspinor::get_vector_order(nSpin * 6); else return colorspinor::get_vector_order(4); } @@ -339,7 +351,7 @@ namespace quda template <> constexpr int n_vector(int nSpin, int site_unroll) { if (site_unroll) - return nSpin == 4 ? colorspinor::get_vector_order(24) : colorspinor::get_vector_order(6); + return colorspinor::get_vector_order(nSpin * 6); else return colorspinor::get_vector_order(8); } @@ -347,7 +359,7 @@ namespace quda template <> constexpr int n_vector(int nSpin, int site_unroll) { if (site_unroll) - return nSpin == 4 ? colorspinor::get_vector_order(24) : colorspinor::get_vector_order(6); + return colorspinor::get_vector_order(nSpin * 6); else return colorspinor::get_vector_order(16); } @@ -355,7 +367,7 @@ namespace quda template <> constexpr int n_vector(int nSpin, int site_unroll) { if (site_unroll) - return nSpin == 4 ? colorspinor::get_vector_order(24) : colorspinor::get_vector_order(6); + return colorspinor::get_vector_order(nSpin * 6); else return colorspinor::get_vector_order(16); } @@ -385,13 +397,18 @@ namespace quda constexpr void instantiate(const T &a, const T &b, const T &c, V &x_, Args &&... 
args) { unwrap_t &x(x_); - if (x.Nspin() == 4 || x.Nspin() == 2) { - if constexpr (is_enabled_spin(2) || is_enabled_spin(4)) { - // Nspin-2 takes Nspin-4 path here, and we check for this later + if (x.Nspin() == 4) { + if constexpr (is_enabled_spin(4)) { Blas(a, b, c, x, args...); } else { errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); } + } else if (x.Nspin() == 2) { + if constexpr (is_enabled_spin(2)) { + Blas(a, b, c, x, args...); + } else { + errorQuda("blas has not been built for Nspin=%d fields", x.Nspin()); + } } else { if constexpr (is_enabled_spin(1)) { Blas(a, b, c, x, args...); diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 60521dba72..d5f429f418 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -102,7 +102,7 @@ namespace quda struct ColorSpinorParam : public LatticeFieldParam { int nColor = 0; // Number of colors of the field - int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor + int nSpin = 0; // =1 for staggered, =2 for coarse Dslash and chiral overlap Dslash, =4 for 4d spinor int nVec = 1; // number of packed vectors (for multigrid transfer operator) int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded) @@ -1084,6 +1084,39 @@ namespace quda */ void spinorDistanceReweight(ColorSpinorField &src, double alpha0, int t0); + /** + @brief Reconstruct a chiral spinor into a full spinor + @param[out] dst The reconstructed full spinor nSpin = 4 + @param[in] src The chiral spinor nSpin = 2 + @param[in] chirality The chirality of the reconstruction + */ + void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality); + + /** + @brief Reconstruct two chiral spinors into a full spinor + @param[out] dst The reconstructed full spinor nSpin = 4 + @param[in] src_left The left chirality part nSpin = 2 + @param[in] src_right The right chirality part nSpin = 2 + */ + void 
spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src_left, + const ColorSpinorField &src_right); + + /** + @brief Project a full spinor to a chiral spinor + @param[out] dst The projected chiral spinor nSpin = 2 + @param[in] src The full spinor nSpin = 4 + @param[in] chirality The chirality of the projection + */ + void spinorChiralProject(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality); + + /** + @brief Project a full spinor to two chiral spinors + @param[out] dst_left The projected left chirality part nSpin = 2 + @param[out] dst_right The projected left chirality part nSpin = 2 + @param[in] src The full spinor nSpin = 4 + */ + void spinorChiralProject(ColorSpinorField &dst_left, ColorSpinorField &dst_right, const ColorSpinorField &src); + /** @brief Helper function for determining if the spin of the fields is the same. @param[in] a Input field diff --git a/include/dirac_quda.h b/include/dirac_quda.h index 7372c956e3..00ab6ab2a8 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace quda { @@ -67,6 +68,8 @@ namespace quda { bool use_mobius_fused_kernel; // Whether or not use fused kernels for Mobius + OverlapKernel *overlap_kernel; + double distance_pc_alpha0; // used by distance preconditioning int distance_pc_t0; // used by distance preconditioning @@ -149,6 +152,7 @@ namespace quda { class DiracMMdag; class DiracMdag; class DiracG5M; + class DiracMdagMChiral; //Forward declaration of multigrid Transfer class class Transfer; @@ -162,6 +166,7 @@ namespace quda { friend class DiracMMdag; friend class DiracMdag; friend class DiracG5M; + friend class DiracMdagMChiral; protected: GaugeField *gauge; @@ -350,6 +355,14 @@ namespace quda { */ virtual void MMdag(cvector_ref &out, cvector_ref &in) const; + /** + @brief Apply MdagM on single chirality + */ + virtual void MdagMChiral(cvector_ref &, cvector_ref &, QudaChirality) const + { + errorQuda("Not 
implemented!"); + } + /** @brief Prepare the source and solution vectors for solving given the solution type @param[out] out Prepared solution vectors @@ -1415,6 +1428,47 @@ namespace quda { virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; + // Full overlap + class DiracOverlap : public Dirac + { + + protected: + OverlapKernel *overlap_kernel; + + public: + DiracOverlap(const DiracParam ¶m); + DiracOverlap(const DiracOverlap &dirac); + virtual ~DiracOverlap(); + DiracOverlap &operator=(const DiracOverlap &dirac); + + virtual void Dslash(cvector_ref &out, cvector_ref &in, + QudaParity parity) const override; + virtual void DslashXpay(cvector_ref &out, cvector_ref &in, + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagMChiral(cvector_ref &out, cvector_ref &in, + QudaChirality chirality) const override; + + virtual void prepare(cvector_ref &out, cvector_ref &in, + cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const override; + virtual void reconstruct(cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const override; + + virtual int getStencilSteps() const override { return 2 * (overlap_kernel->remez_order[0] + 1) + 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_OVERLAP_DIRAC; } + + /** + @brief If managed memory and prefetch is enabled, prefetch + all relevant memory fields (gauge, clover, temporary spinors) + to the CPU or GPU as requested + @param[in] mem_space Memory space we are prefetching to + @param[in] stream Which stream to run the prefetch in (default 0) + */ + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; + }; + // Full staggered class DiracStaggered : public Dirac { @@ -2508,6 
+2562,7 @@ namespace quda { case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: case QUDA_TWISTED_MASS_DIRAC: case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_OVERLAP_DIRAC: // while the twisted ops don't have a Hermitian indefinite spectrum, they // do have a spectrum of the form (real) + i mu gamma5(vec, vec); @@ -2593,6 +2648,8 @@ namespace quda { || dirac_type == QUDA_GAUGE_COVDEV_DIRAC) return true; + if (dirac_type == QUDA_WILSON_DIRAC || dirac_type == QUDA_CLOVER_DIRAC) return true; + // subtle: odd operator gets a minus sign if ((dirac_type == QUDA_STAGGEREDPC_DIRAC || dirac_type == QUDA_ASQTADPC_DIRAC) && (pc_type == QUDA_MATPC_EVEN_EVEN || pc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC)) @@ -2602,6 +2659,46 @@ namespace quda { } }; + /** + Gloms onto a DiracMatrix and provides an operator() for its MdagMChiral method + */ + class DiracMdagMChiral : public DiracMatrix + { + protected: + QudaChirality chirality; + + public: + DiracMdagMChiral(const Dirac &d) : DiracMatrix(d) { } + DiracMdagMChiral(const Dirac *d) : DiracMatrix(d) { } + + /** + @brief Multi-RHS operator application. + @param[out] out The vector of output fields + @param[in] in The vector of input fields + */ + void operator()(cvector_ref &out, cvector_ref &in) const override + { + dirac->MdagMChiral(out, in, chirality); + if (shift != 0.0) blas::axpy(shift, in, out); + } + + int getStencilSteps() const override + { + if (dirac->getDiracType() == QUDA_OVERLAP_DIRAC) { + return dirac->getStencilSteps(); // P M^dag M P == P M P for overlap chiral fermion + } else { + return dirac->getStencilSteps() * 2; // 2 for M and M dagger + } + } + + /** + @brief return if the operator is HPD + */ + virtual bool hermitian() const override { return true; } + + void setChirality(QudaChirality chirality_in) { chirality = chirality_in; } + }; + /** * Create the Dirac operator. By default, we also create operators with possibly different * precisions: Sloppy, and Preconditioner. 
diff --git a/include/enum_quda.h b/include/enum_quda.h index 1de9c3be35..f37674db51 100644 --- a/include/enum_quda.h +++ b/include/enum_quda.h @@ -96,6 +96,7 @@ typedef enum QudaDslashType_s { QUDA_DOMAIN_WALL_4D_DSLASH, QUDA_MOBIUS_DWF_DSLASH, QUDA_MOBIUS_DWF_EOFA_DSLASH, + QUDA_OVERLAP_DSLASH, QUDA_STAGGERED_DSLASH, QUDA_ASQTAD_DSLASH, QUDA_TWISTED_MASS_DSLASH, @@ -170,8 +171,10 @@ typedef enum QudaSolveType_s { QUDA_NORMOP_PC_SOLVE, QUDA_NORMERR_SOLVE, QUDA_NORMERR_PC_SOLVE, - QUDA_NORMEQ_SOLVE = QUDA_NORMOP_SOLVE, // deprecated - QUDA_NORMEQ_PC_SOLVE = QUDA_NORMOP_PC_SOLVE, // deprecated + QUDA_NORMERR_CHIRAL_SOLVE, + QUDA_NORMOP_CHIRAL_SOLVE = QUDA_NORMERR_CHIRAL_SOLVE, // P MdagM P == P MMdag P + QUDA_NORMEQ_SOLVE = QUDA_NORMOP_SOLVE, // deprecated + QUDA_NORMEQ_PC_SOLVE = QUDA_NORMOP_PC_SOLVE, // deprecated QUDA_INVALID_SOLVE = QUDA_INVALID_ENUM } QudaSolveType; @@ -305,6 +308,7 @@ typedef enum QudaDiracType_s { QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC, QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC, QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC, + QUDA_OVERLAP_DIRAC, QUDA_STAGGERED_DIRAC, QUDA_STAGGEREDPC_DIRAC, QUDA_STAGGEREDKD_DIRAC, @@ -641,6 +645,12 @@ typedef enum QudaExtLibType_s { QUDA_EXTLIB_INVALID = QUDA_INVALID_ENUM } QudaExtLibType; +typedef enum QudaChirality_s { + QUDA_LEFT_CHIRALITY = -1, // (1 - \gamma_5) / 2 + QUDA_RIGHT_CHIRALITY = +1, // (1 + \gamma_5) / 2 + QUDA_INVALID_CHIRALITY = QUDA_INVALID_ENUM +} QudaChirality; + typedef enum QudaDDType_s { QUDA_DD_NO, QUDA_DD_RED_BLACK, QUDA_DD_INVALID = QUDA_INVALID_ENUM } QudaDDType; typedef enum QudaWFlowStepType_s { diff --git a/include/enum_quda_fortran.h b/include/enum_quda_fortran.h index 33fa5a9ad8..70e2af9723 100644 --- a/include/enum_quda_fortran.h +++ b/include/enum_quda_fortran.h @@ -81,12 +81,13 @@ #define QUDA_DOMAIN_WALL_4D_DSLASH 4 #define QUDA_MOBIUS_DWF_DSLASH 5 #define QUDA_MOBIUS_DWF_EOFA_DSLASH 6 -#define QUDA_STAGGERED_DSLASH 7 -#define QUDA_ASQTAD_DSLASH 8 -#define QUDA_TWISTED_MASS_DSLASH 9 -#define 
QUDA_TWISTED_CLOVER_DSLASH 10 -#define QUDA_LAPLACE_DSLASH 11 -#define QUDA_COVDEV_DSLASH 12 +#define QUDA_OVERLAP_DSLASH 7 +#define QUDA_STAGGERED_DSLASH 8 +#define QUDA_ASQTAD_DSLASH 9 +#define QUDA_TWISTED_MASS_DSLASH 10 +#define QUDA_TWISTED_CLOVER_DSLASH 11 +#define QUDA_LAPLACE_DSLASH 12 +#define QUDA_COVDEV_DSLASH 13 #define QUDA_INVALID_DSLASH QUDA_INVALID_ENUM #define QudaInverterType integer(4) @@ -148,6 +149,8 @@ #define QUDA_NORMOP_PC_SOLVE 3 #define QUDA_NORMERR_SOLVE 4 #define QUDA_NORMERR_PC_SOLVE 5 +#define QUDA_NORMERR_CHIRAL_SOLVE 6 +#define QUDA_NORMOP_CHIRAL_SOLVE QUDA_NORMERR_CHIRAL_SOLVE // P M^\dag M P == P M M^\dag P #define QUDA_NORMEQ_SOLVE QUDA_NORMOP_SOLVE // deprecated #define QUDA_NORMEQ_PC_SOLVE QUDA_NORMOP_PC_SOLVE // deprecated #define QUDA_INVALID_SOLVE QUDA_INVALID_ENUM @@ -272,21 +275,22 @@ #define QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC 11 #define QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC 12 #define QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC 13 -#define QUDA_STAGGERED_DIRAC 14 -#define QUDA_STAGGEREDPC_DIRAC 15 -#define QUDA_STAGGEREDKD_DIRAC 16 -#define QUDA_ASQTAD_DIRAC 17 -#define QUDA_ASQTADPC_DIRAC 18 -#define QUDA_ASQTADKD_DIRAC 19 -#define QUDA_TWISTED_MASS_DIRAC 20 -#define QUDA_TWISTED_MASSPC_DIRAC 21 -#define QUDA_TWISTED_CLOVER_DIRAC 22 -#define QUDA_TWISTED_CLOVERPC_DIRAC 23 -#define QUDA_COARSE_DIRAC 24 -#define QUDA_COARSEPC_DIRAC 25 -#define QUDA_GAUGE_LAPLACE_DIRAC 26 -#define QUDA_GAUGE_LAPLACEPC_DIRAC 27 -#define QUDA_GAUGE_COVDEV_DIRAC 28 +#define QUDA_OVERLAP_DIRAC 14 +#define QUDA_STAGGERED_DIRAC 15 +#define QUDA_STAGGEREDPC_DIRAC 16 +#define QUDA_STAGGEREDKD_DIRAC 17 +#define QUDA_ASQTAD_DIRAC 18 +#define QUDA_ASQTADPC_DIRAC 19 +#define QUDA_ASQTADKD_DIRAC 20 +#define QUDA_TWISTED_MASS_DIRAC 21 +#define QUDA_TWISTED_MASSPC_DIRAC 22 +#define QUDA_TWISTED_CLOVER_DIRAC 23 +#define QUDA_TWISTED_CLOVERPC_DIRAC 24 +#define QUDA_COARSE_DIRAC 25 +#define QUDA_COARSEPC_DIRAC 26 +#define QUDA_GAUGE_LAPLACE_DIRAC 27 +#define 
QUDA_GAUGE_LAPLACEPC_DIRAC 28 +#define QUDA_GAUGE_COVDEV_DIRAC 29 #define QUDA_INVALID_DIRAC QUDA_INVALID_ENUM ! Where the field is stored @@ -540,12 +544,16 @@ #define QUDA_FERMION_SMEAR_TYPE_WUPPERTAL 1 #define QUDA_FERMION_SMEAR_TYPE_INVALID QUDA_INVALID_ENUM - #define QudaExtLibType integer(4) #define QUDA_CUSOLVE_EXTLIB 0 #define QUDA_EIGEN_EXTLIB 1 #define QUDA_EXTLIB_INVALID QUDA_INVALID_ENUM +#define QudaChirality integer(4) +#define QUDA_LEFT_CHIRALITY -1 +#define QUDA_RIGHT_CHIRALITY +1 +#define QUDA_INVALID_CHIRALITY QUDA_INVALID_ENUM + #define QudaDDType integer(4) #define QUDA_DD_NO 0 #define QUDA_DD_RED_BLACK 1 diff --git a/include/kernels/spinor_chiral_project.cuh b/include/kernels/spinor_chiral_project.cuh new file mode 100644 index 0000000000..e72736355c --- /dev/null +++ b/include/kernels/spinor_chiral_project.cuh @@ -0,0 +1,100 @@ +#include +#include +#include +#include + +namespace quda +{ + using namespace colorspinor; + + template + struct ChiralReconstructSpinorArg : kernel_param<> { + using real = typename mapper::type; + static constexpr int nSpin = 4; + static constexpr int nColor = nColor_; + static constexpr QudaChirality Chirality = Chirality_; + using Vout = typename colorspinor_mapper::type; + using Vin = typename colorspinor_mapper::type; + + Vout out; + const Vin in_left; + const Vin in_right; + ChiralReconstructSpinorArg(ColorSpinorField &out, const ColorSpinorField &in_left, const ColorSpinorField &in_right) : + kernel_param(dim3(out.VolumeCB(), out.SiteSubset(), 1)), out(out), in_left(in_left), in_right(in_right) + { + } + }; + + template struct ChiralReconstructSpinor { + const Arg &arg; + constexpr ChiralReconstructSpinor(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity) + { + using real = typename Arg::real; + using Vector = ColorSpinor; + using HalfVector = ColorSpinor; + const real invsqrt2 = (real)(1.0 / sqrt(2.0)); 
+ + Vector out; + HalfVector in; + if constexpr (Arg::Chirality == QUDA_LEFT_CHIRALITY || Arg::Chirality == QUDA_INVALID_CHIRALITY) { + in = arg.in_left(x_cb, parity); + out += in.chiral_reconstruct(1); + } + if constexpr (Arg::Chirality == QUDA_RIGHT_CHIRALITY || Arg::Chirality == QUDA_INVALID_CHIRALITY) { + in = arg.in_right(x_cb, parity); + out += in.chiral_reconstruct(0); + } + out.toNonRel(); + out *= invsqrt2; + arg.out(x_cb, parity) = out; + } + }; + + template struct ChiralProjectSpinorArg : kernel_param<> { + using real = typename mapper::type; + static constexpr int nSpin = 4; + static constexpr int nColor = nColor_; + static constexpr QudaChirality Chirality = Chirality_; + using Vout = typename colorspinor_mapper::type; + using Vin = typename colorspinor_mapper::type; + + Vout out_left; + Vout out_right; + const Vin in; + ChiralProjectSpinorArg(ColorSpinorField &out_left, ColorSpinorField &out_right, const ColorSpinorField &in) : + kernel_param(dim3(in.VolumeCB(), in.SiteSubset(), 1)), out_left(out_left), out_right(out_right), in(in) + { + } + }; + + template struct ChiralProjectSpinor { + const Arg &arg; + constexpr ChiralProjectSpinor(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity) + { + using real = typename Arg::real; + using HalfVector = ColorSpinor; + using Vector = ColorSpinor; + const real invsqrt2 = (real)(1.0 / sqrt(2.0)); + + HalfVector out; + Vector in = arg.in(x_cb, parity); + in.toRel(); + in *= invsqrt2; + if constexpr (Arg::Chirality == QUDA_LEFT_CHIRALITY || Arg::Chirality == QUDA_INVALID_CHIRALITY) { + out = in.chiral_project(1); + arg.out_left(x_cb, parity) = out; + } + if constexpr (Arg::Chirality == QUDA_RIGHT_CHIRALITY || Arg::Chirality == QUDA_INVALID_CHIRALITY) { + out = in.chiral_project(0); + arg.out_right(x_cb, parity) = out; + } + } + }; + +} // namespace quda diff --git a/include/overlap_kernel.h 
b/include/overlap_kernel.h new file mode 100644 index 0000000000..9ecff20986 --- /dev/null +++ b/include/overlap_kernel.h @@ -0,0 +1,31 @@ +/** + @file overlap.h + + @section DESCRIPTION +*/ + +#pragma once + +#include +#include + +namespace quda +{ + struct OverlapKernel { + std::vector evecs; + std::vector evals; + double kappa; + double epsilon; + std::vector remez_tol; + std::vector> remez_coeff; + std::vector remez_order; + + OverlapKernel(std::vector &evecs, const std::vector &evals, double kappa, + const std::vector remez_tol); + OverlapKernel(const OverlapKernel *overlap_kernel, QudaPrecision precision); + ~OverlapKernel() = default; + + inline QudaPrecision Precision() const { return evecs[0].Precision(); } + inline double Kappa() const { return kappa; } + }; +} // namespace quda diff --git a/include/quda.h b/include/quda.h index 13a0a69140..945eb2b4f5 100644 --- a/include/quda.h +++ b/include/quda.h @@ -139,6 +139,9 @@ extern "C" { QudaTwistFlavorType twist_flavor; /**< Twisted mass flavor */ + /** Parameters for overlap fermion */ + double overlap_invsqrt_tol; + int laplace3D; /**< omit this direction from laplace operator: x,y,z,t -> 0,1,2,3 (-1 is full 4D) */ int covdev_mu; /**< Apply forward/backward covariant derivative in direction mu(mu<=3)/mu-4(mu>3) */ @@ -523,6 +526,9 @@ extern "C" { QudaBoolean use_norm_op; QudaBoolean use_pc; + /** Use chiral version of MdagM */ + QudaChirality chirality; + /** Use Eigen routines to eigensolve the upper Hessenberg via QR **/ QudaBoolean use_eigen_qr; @@ -1218,6 +1224,9 @@ extern "C" { */ void eigensolveQuda(void **h_evecs, double_complex *h_evals, QudaEigParam *param); + void loadOverlapQuda(QudaInvertParam *inv_param, QudaEigParam *eig_param); + void freeOverlapQuda(); + /** * Perform the solve, according to the parameters set in param. 
It * is assumed that the gauge field has already been loaded via diff --git a/include/quda_internal.h b/include/quda_internal.h index c7ba9f0489..6b6795793e 100644 --- a/include/quda_internal.h +++ b/include/quda_internal.h @@ -22,15 +22,15 @@ // these are helper macros used to enable spin-1, spin-2 and spin-4 building blocks as needed #if defined(GPU_WILSON_DIRAC) || defined(GPU_DOMAIN_WALL_DIRAC) || defined(GPU_CLOVER_DIRAC) \ || defined(GPU_TWISTED_MASS_DIRAC) || defined(GPU_TWISTED_CLOVER_DIRAC) || defined(GPU_CLOVER_HASENBUSCH_TWIST) \ - || defined(GPU_COVDEV) || defined(GPU_CONTRACT) + || defined(GPU_LAPLACE) || defined(GPU_COVDEV) || defined(GPU_CONTRACT) #define NSPIN4 #endif -#if defined(GPU_MULTIGRID) +#if defined(GPU_WILSON_DIRAC) || defined(GPU_MULTIGRID) #define NSPIN2 #endif -#if defined(GPU_STAGGERED_DIRAC) || defined(GPU_LAPLACE) +#if defined(GPU_STAGGERED_DIRAC) || defined(GPU_LAPLACE) || defined(GPU_COVDEV) #define NSPIN1 #endif diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 6043cc0250..be0947cab5 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -45,7 +45,7 @@ set (QUDA_OBJS dirac_staggered_kd.cpp dirac_clover_hasenbusch_twist.cpp dirac_improved_staggered.cpp dirac_improved_staggered_kd.cpp dirac_domain_wall.cpp dirac_domain_wall_4d.cpp dirac_mobius.cpp dirac_twisted_clover.cpp - dirac_twisted_mass.cpp + dirac_twisted_mass.cpp dirac_overlap.cpp overlap_kernel.cpp llfat_quda.cu staggered_two_link_quda.cu gauge_force.cu gauge_loop_trace.cu gauge_polyakov_loop.cu gauge_random.cu gauge_noise.cu gauge_field_strength_tensor.cu clover_quda.cu @@ -81,6 +81,7 @@ set (QUDA_OBJS extract_gauge_ghost_extended.cu copy_color_spinor.cpp spin_duplicate.cu spinor_noise.cu spinor_dilute.cu spinor_reweight.cu + spinor_chiral_project.cu copy_color_spinor_dd.cu copy_color_spinor_ds.cu copy_color_spinor_dh.cu copy_color_spinor_dq.cu copy_color_spinor_ss.cu copy_color_spinor_sd.cu diff --git a/lib/blas_quda.cu b/lib/blas_quda.cu index 
163aea688b..9f251fe162 100644 --- a/lib/blas_quda.cu +++ b/lib/blas_quda.cu @@ -64,7 +64,8 @@ namespace quda { void apply(const qudaStream_t &stream) override { constexpr bool site_unroll_check = !std::is_same::value || isFixed::value; - if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2)) + // TODO: Is x.Nspin() == 2 check needed here? + if (site_unroll_check && (x.Ncolor() != 3 && x.Nspin() == 2)) errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor()); if (location == QUDA_CUDA_FIELD_LOCATION) { @@ -78,7 +79,8 @@ namespace quda { constexpr bool site_unroll = !std::is_same::value || isFixed::value; constexpr int N = n_vector(nSpin, site_unroll); constexpr int Ny = n_vector(nSpin, site_unroll); - constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread + // TODO: Shall we use n_vector(nSpin, true) here? + constexpr int M = site_unroll ? (nSpin * 6) : N; // real numbers per thread const int threads = x.Length() / (nParity * M); TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); diff --git a/lib/check_params.h b/lib/check_params.h index 6dbdc91dfd..4ff2d278bf 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -236,6 +236,10 @@ void printQudaEigParam(QudaEigParam *param) { P(partfile, QUDA_BOOLEAN_INVALID); #endif +#ifndef CHECK_PARAM + P(chirality, QUDA_INVALID_CHIRALITY); +#endif + // only need to enfore block size checking if doing a block eigen solve #ifdef CHECK_PARAM if (param->eig_type == QUDA_EIG_BLK_TR_LANCZOS) @@ -374,17 +378,19 @@ void printQudaInvertParam(QudaInvertParam *param) { P(evmax, INVALID_DOUBLE); P(tm_rho, 0.0); P(twist_flavor, QUDA_TWIST_INVALID); + P(overlap_invsqrt_tol, INVALID_DOUBLE); P(laplace3D, INVALID_INT); P(covdev_mu, INVALID_INT); #else - // asqtad and domain wall use mass parameterization + // staggered, overlap, and domain wall use mass parameterization if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == 
QUDA_ASQTAD_DSLASH - || param->dslash_type == QUDA_DOMAIN_WALL_DSLASH || param->dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH - || param->dslash_type == QUDA_MOBIUS_DWF_DSLASH) { + || param->dslash_type == QUDA_OVERLAP_DSLASH || param->dslash_type == QUDA_DOMAIN_WALL_DSLASH + || param->dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || param->dslash_type == QUDA_MOBIUS_DWF_DSLASH) { P(mass, INVALID_DOUBLE); } else { // Wilson and clover use kappa parameterization P(kappa, INVALID_DOUBLE); } + if (param->dslash_type == QUDA_OVERLAP_DSLASH) { P(overlap_invsqrt_tol, INVALID_DOUBLE); } if (param->dslash_type == QUDA_DOMAIN_WALL_DSLASH || param->dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || param->dslash_type == QUDA_MOBIUS_DWF_DSLASH ) { diff --git a/lib/clover_field.cpp b/lib/clover_field.cpp index 2f0078aacf..1ffa057db1 100644 --- a/lib/clover_field.cpp +++ b/lib/clover_field.cpp @@ -185,7 +185,8 @@ namespace quda { { LatticeField::setTuningString(); std::stringstream aux_ss; - aux_ss << "vol=" << volume << "precision=" << precision << "Nc=" << nColor << ",order=" << order; + aux_ss << "vol=" << volume << "precision=" << precision << "Nc=" << nColor << ",memory=" << mem_type + << ",order=" << order; if (isNative()) aux_ss << ",N=" << clover::get_vector_order(precision, 128); aux_string = aux_ss.str(); if (aux_string.size() >= TuneKey::aux_n / 2) errorQuda("Aux string too large %lu", aux_string.size()); diff --git a/lib/color_spinor_field.cpp b/lib/color_spinor_field.cpp index 72a3f5f62a..50c9e562a8 100644 --- a/lib/color_spinor_field.cpp +++ b/lib/color_spinor_field.cpp @@ -296,7 +296,7 @@ namespace quda if (init) { std::stringstream aux_ss; aux_ss << "vol=" << volume << ",parity=" << siteSubset << ",precision=" << precision << ",Ns=" << nSpin - << ",Nc=" << nColor << ",order=" << fieldOrder; + << ",Nc=" << nColor << ",memory=" << mem_type << ",order=" << fieldOrder; if (isNative()) aux_ss << ",N=" << colorspinor::get_vector_order(precision, 128); if (nVec > 1) aux_ss << 
",nVec=" << nVec; if (twistFlavor != QUDA_TWIST_NO && twistFlavor != QUDA_TWIST_INVALID) aux_ss << ",TwistFlavor=" << twistFlavor; diff --git a/lib/dirac.cpp b/lib/dirac.cpp index a5ec216943..c8e19c06cd 100644 --- a/lib/dirac.cpp +++ b/lib/dirac.cpp @@ -48,6 +48,7 @@ namespace quda { type(dirac.type), halo_precision(dirac.halo_precision), commDim(dirac.commDim), + use_mobius_fused_kernel(dirac.use_mobius_fused_kernel), // TODO: Shall we copy this? distance_pc_alpha0(dirac.distance_pc_alpha0), distance_pc_t0(dirac.distance_pc_t0), profile("Dirac", false) @@ -72,6 +73,7 @@ namespace quda { symmetric = dirac.symmetric; dagger = dirac.dagger; commDim = dirac.commDim; + use_mobius_fused_kernel = dirac.use_mobius_fused_kernel; // TODO: Shall we copy this? distance_pc_alpha0 = dirac.distance_pc_alpha0; distance_pc_t0 = dirac.distance_pc_t0; profile = dirac.profile; @@ -186,6 +188,9 @@ namespace quda { } else if (param.type == QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC) { if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobiusEofaPC operator\n"); return new DiracMobiusEofaPC(param); + } else if (param.type == QUDA_OVERLAP_DIRAC) { + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracOverlap operator\n"); + return new DiracOverlap(param); } else if (param.type == QUDA_STAGGERED_DIRAC) { if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracStaggered operator\n"); return new DiracStaggered(param); diff --git a/lib/dirac_overlap.cpp b/lib/dirac_overlap.cpp new file mode 100644 index 0000000000..5f7e32a86d --- /dev/null +++ b/lib/dirac_overlap.cpp @@ -0,0 +1,181 @@ +#include +#include +#include +#include + +namespace quda +{ + /** + * Apply the overlap overlap + * out = D * in + * If m is not zero, then + * out = m * x + (1 - m) * D * in + * D is defined as 0.5 * (1 + \gamma_5 sign(\gamma_5 M)) where M is the Wilson operator + */ + void ApplyOverlap(cvector_ref &out, cvector_ref &in, const GaugeField &U, + OverlapKernel &O, 
double m, cvector_ref &x, int parity, bool dagger, + const int *comm_override, TimeProfile &profile) + { + auto in_def = getFieldTmp(out); + auto b1 = getFieldTmp(out); + auto b2 = getFieldTmp(out); + auto Mb1 = getFieldTmp(out); + auto Ab1 = getFieldTmp(out); + + cvector_ref &evecs = O.evecs; + cvector &evals = O.evals; + const double remez_order = O.remez_order[0]; + cvector &remez_coeff = O.remez_coeff[0]; + const double lambda_max = (1.0 + 8.0 * O.kappa); + const double epsilon = O.epsilon; + + /** + * Apply (1 - m) * 0.5 directly to the input + */ + if (dagger) { + blas::axy((1 - m) * 0.5, in, out); + gamma5(in_def, out); + } else { + blas::axy((1 - m) * 0.5, in, in_def); + gamma5(out, in_def); + } + + /** + * \gamma_5 sign(\gamma_5 M) for small eigenvalues + * Define the eigenvalues and eigenvectors \gamma_5 M v_i = \lambda_i v_i + * ==> \gamma_5 \sum_i sign(\lambda_i) |v_i> alpha(evecs.size() * in_def.size()); + blas::block::cDotProduct(alpha, evecs, in_def); + for (auto &v : alpha) { v *= -1; } + blas::block::caxpy(alpha, evecs, in_def); + for (size_t i = 0; i < evecs.size(); i++) { + for (size_t j = 0; j < in_def.size(); ++j) { alpha[i * in_def.size() + j] *= -evals[i] / abs(evals[i]); } + } + blas::block::caxpy(alpha, evecs, out); + if (!dagger) { gamma5(out, out); } + + /** + * \gamma_5 sign(\gamma_5 M) for large eigenvalues + * Define the Chebyshev polynomial approximation P(x) ~ x^{-1/2} + * ==> M P(M^\dagger M) + * Here M is the normalized Wilson operator which has the maximum eigenvalue 1 + */ + blas::zero(b1); + blas::zero(b2); + for (int k = remez_order; k >= 0; --k) { + ApplyWilson(Mb1, b1, U, -O.kappa, b1, parity, false, comm_override, profile); + ApplyWilson(Ab1, Mb1, U, -O.kappa, Mb1, parity, true, comm_override, profile); + blas::axpby(-(1.0 + epsilon) / (1.0 - epsilon), b1, 2.0 / (1.0 - epsilon) / (lambda_max * lambda_max), Ab1); + if (k > 0) { + blas::axpbypczw(remez_coeff[k], in_def, 2.0, Ab1, -1.0, b2, b2); + } else { + 
blas::axpbypczw(remez_coeff[0], in_def, 1.0, Ab1, -1.0, b2, b2); + } + std::swap(b1, b2); + } + ApplyWilson(Mb1, b1, U, -O.kappa, b1, parity, false, comm_override, profile); + if (dagger) { gamma5(Mb1, Mb1); } + if (m == 0.0) { + blas::axpbyz(1.0 / lambda_max, Mb1, 1.0, out, out); + } else { + blas::axpbypczw(m, x, 1.0 / lambda_max, Mb1, 1.0, out, out); + } + } + + DiracOverlap::DiracOverlap(const DiracParam ¶m) : Dirac(param), overlap_kernel(param.overlap_kernel) { } + + DiracOverlap::DiracOverlap(const DiracOverlap &dirac) : Dirac(dirac), overlap_kernel(dirac.overlap_kernel) { } + + DiracOverlap::~DiracOverlap() { } + + DiracOverlap &DiracOverlap::operator=(const DiracOverlap &dirac) + { + if (&dirac != this) { + Dirac::operator=(dirac); + overlap_kernel = dirac.overlap_kernel; + } + return *this; + } + + void DiracOverlap::Dslash(cvector_ref &, cvector_ref &, QudaParity) const + { + errorQuda("The overlap Dirac operator does not have a single parity form"); + } + + void DiracOverlap::DslashXpay(cvector_ref &, cvector_ref &, QudaParity, + cvector_ref &, double) const + { + errorQuda("The overlap Dirac operator does not have a single parity form"); + } + + // Defined as m + (1 - m) D + void DiracOverlap::M(cvector_ref &out, cvector_ref &in) const + { + ApplyOverlap(out, in, *gauge, *overlap_kernel, mass, in, QUDA_INVALID_PARITY, dagger, commDim.data, profile); + } + + // Defined as m^2 + (1 - m^2) DdagD + void DiracOverlap::MdagM(cvector_ref &out, cvector_ref &in) const + { + auto tmp = getFieldTmp(out); + ApplyOverlap(out, in, *gauge, *overlap_kernel, 0.0, in, QUDA_INVALID_PARITY, dagger, commDim.data, profile); + flipDagger(); + ApplyOverlap(out, in, *gauge, *overlap_kernel, mass * mass, in, QUDA_INVALID_PARITY, dagger, commDim.data, profile); + flipDagger(); + } + + // Defined as m^2 + (1 - m^2) P DdagD P where P = (1 +- gamma_5) / 2 + // For overlap dslash P DdagD P = P D P + void DiracOverlap::MdagMChiral(cvector_ref &out, cvector_ref &in, + QudaChirality 
chirality) const + { + ColorSpinorParam param(in[0]); + param.nSpin = 4; + param.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; + param.mem_type = QUDA_MEMORY_DEVICE; // TODO: Hack for eigensolver in the host memory + param.setPrecision(param.Precision(), param.Precision(), true); + auto in_tmp = getFieldTmp(in.size(), param); + auto out_tmp = getFieldTmp(out.size(), param); + + for (size_t i = 0; i < in.size(); i++) { spinorChiralReconstruct(in_tmp[i], in[i], chirality); } + ApplyOverlap(out_tmp, in_tmp, *gauge, *overlap_kernel, mass * mass, in_tmp, QUDA_INVALID_PARITY, dagger, + commDim.data, profile); + for (size_t i = 0; i < out.size(); i++) { spinorChiralProject(out[i], out_tmp[i], chirality); } + } + + void DiracOverlap::prepare(cvector_ref &out, cvector_ref &in, + cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const + { + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { + errorQuda("Preconditioned solution requires a preconditioned solve_type"); + } + + create_alias(in, b); + create_alias(out, x); + } + + void DiracOverlap::reconstruct(cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const + { + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { return; } + + if (solType == QUDA_MAT_SOLUTION) { + // We actually apply (1 - D) x' + // x = -1 / (1 - m) * b + 1 / (1 - m) * x' + // x' = M^{-1} * b = (m + (1 - m) D)^{-1} * b + blas::axpby(-1.0 / (1.0 - mass), b, 1.0 / (1.0 - mass), x); + } else if (solType == QUDA_MATDAG_MAT_SOLUTION) { + // We actually apply (1 - DdagD) x' + // x = -1 / (1 - m^2) * b + 1 / (1 - m^2) * x' + // x' = (MdagM)^{-1} * b = (m^2 + (1 - m^2) DdagD)^{-1} * b + blas::axpby(-1.0 / (1.0 - mass * mass), b, 1.0 / (1.0 - mass * mass), x); + } + } + + void DiracOverlap::prefetch(QudaFieldLocation mem_space, qudaStream_t stream) const + { + Dirac::prefetch(mem_space, stream); + } +} // namespace quda \ No newline at end of file diff --git 
a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 9d1123fb3d..23852ab34b 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -282,19 +282,30 @@ namespace quda double d2 = 1.0; double d3; - // out = d2 * in + d1 * out - // C_1(x) = x - mat({out.begin(), out.end()}, {in.begin(), in.end()}); - blas::caxpby(d2, in, d1, out); - - if (eig_param->poly_deg == 1) return; + ColorSpinorParam param(in[0]); + bool use_out_as_tmp = param.mem_type == QUDA_MEMORY_DEVICE; + param.mem_type = QUDA_MEMORY_DEVICE; // FIXME: Hack for Ritz vectors on the host memory + std::vector tmp1, tmp2, tmp3; + tmp1.reserve(in.size()); // C_{m-1}(x) + tmp2.reserve(in.size()); // C_{m}(x) + tmp3.reserve(in.size()); // mat*C_{m}(x) + for (auto i = 0u; i < in.size(); i++) { + tmp1.push_back(getFieldTmp(param)); + tmp2.push_back(getFieldTmp(param)); + if (use_out_as_tmp) { + tmp3.push_back(out[i]); + } else { + tmp3.push_back(getFieldTmp(param)); + } + } - // C_0 is the current 'in' vector. - // C_1 is the current 'out' vector. + // Clone 'in' to temporary vector. + blas::copy(tmp1, in); - // Clone 'in' to two temporary vectors. 
- std::vector tmp1{in.begin(), in.end()}; - std::vector tmp2{out.begin(), out.end()}; + // out = d2 * in + d1 * out + // C_1(x) = x + mat(tmp2, tmp1); + blas::axpbyz(d2, tmp1, d1, tmp2, tmp2); // Using Chebyshev polynomial recursion relation, // C_{m+1}(x) = 2*x*C_{m} - C_{m-1} @@ -302,7 +313,7 @@ namespace quda double sigma_old = sigma1; // construct C_{m+1}(x) - for (int i = 2; i < eig_param->poly_deg; i++) { + for (int i = 1; i < eig_param->poly_deg; i++) { sigma = 1.0 / (2.0 / sigma1 - sigma_old); d1 = 2.0 * sigma / delta; @@ -311,15 +322,19 @@ namespace quda // FIXME - we could introduce a fused mat + blas kernel here, eliminating one temporary // mat*C_{m}(x) - mat(out, tmp2); + mat(tmp3, tmp2); - blas::axpbypczw(d3, tmp1, d2, tmp2, d1, out, tmp1); + blas::axpbypczw(d3, tmp1, d2, tmp2, d1, tmp3, tmp1); std::swap(tmp1, tmp2); sigma_old = sigma; } - for (auto i = 0u; i < in.size(); i++) std::swap(out[i], tmp2[i]); + if (use_out_as_tmp) { + for (auto i = 0u; i < in.size(); i++) std::swap(out[i], tmp2[i]); + } else { + blas::copy(out, tmp2); + } } double EigenSolver::estimateChebyOpMax(ColorSpinorField &out, ColorSpinorField &in) diff --git a/lib/gauge_field.cpp b/lib/gauge_field.cpp index 96e7a51d11..0b6d7202ee 100644 --- a/lib/gauge_field.cpp +++ b/lib/gauge_field.cpp @@ -319,7 +319,7 @@ namespace quda { LatticeField::setTuningString(); std::stringstream aux_ss; aux_ss << "vol=" << volume << ",stride=" << stride << ",precision=" << precision << ",geometry=" << geometry - << ",Nc=" << nColor << ",order=" << order; + << ",Nc=" << nColor << ",memory=" << mem_type << ",order=" << order; if (isNative()) aux_ss << ",N=" << gauge::get_vector_order(precision, 128); if (ghostExchange == QUDA_GHOST_EXCHANGE_EXTENDED) aux_ss << ",r=" << r[0] << r[1] << r[2] << r[3]; aux_string = aux_ss.str(); diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 2270965b4a..d7db9a145f 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -37,6 +37,8 @@ #include 
#include +#include + #include #include @@ -106,6 +108,12 @@ CloverField *cloverPrecondition = nullptr; CloverField *cloverRefinement = nullptr; CloverField *cloverEigensolver = nullptr; +OverlapKernel *overlapPrecise = nullptr; +OverlapKernel *overlapSloppy = nullptr; +OverlapKernel *overlapPrecondition = nullptr; +OverlapKernel *overlapRefinement = nullptr; +OverlapKernel *overlapEigensolver = nullptr; + GaugeField momResident; GaugeField *extendedGaugeResident = nullptr; @@ -156,6 +164,9 @@ static TimeProfile profileGauge("loadGaugeQuda"); //!< Profile for loadCloverQuda static TimeProfile profileClover("loadCloverQuda"); +//!< Profiler for loadOverlapQuda +static TimeProfile profileOverlap("loadOverlapQuda"); + //!< Profiler for dslashQuda static TimeProfile profileDslash("dslashQuda"); @@ -1133,6 +1144,155 @@ void loadSloppyCloverQuda(const QudaPrecision *prec) } +void freeSloppyOverlapQuda() +{ + if (!initialized) errorQuda("QUDA not initialized"); + + // Delete overlapRefinement if it does not alias overlapSloppy. + if (overlapRefinement != overlapSloppy && overlapRefinement) delete overlapRefinement; + + // Delete overlapPrecondition if it does not alias overlapPrecise, overlapSloppy, or overlapEigensolver. + if (overlapPrecondition != overlapSloppy && overlapPrecondition != overlapPrecise + && overlapPrecondition != overlapEigensolver && overlapPrecondition) + delete overlapPrecondition; + + // Delete overlapEigensolver if it does not alias overlapPrecise or overlapSloppy. + if (overlapEigensolver != overlapSloppy && overlapEigensolver != overlapPrecise && overlapEigensolver) + delete overlapEigensolver; + + // Delete overlapSloppy if it does not alias overlapPrecise. 
+ if (overlapSloppy != overlapPrecise && overlapSloppy) delete overlapSloppy; + + overlapEigensolver = nullptr; + overlapRefinement = nullptr; + overlapPrecondition = nullptr; + overlapSloppy = nullptr; +} + +void freeOverlapQuda(void) +{ + if (!initialized) errorQuda("QUDA not initialized"); + freeSloppyOverlapQuda(); + if (overlapPrecise) { delete overlapPrecise; } + overlapPrecise = nullptr; +} + +void loadSloppyOverlapQuda(const QudaPrecision prec[]) +{ + freeSloppyOverlapQuda(); + + if (overlapPrecise) { + + if (prec[0] == overlapPrecise->Precision()) { + overlapSloppy = overlapPrecise; + } else { + overlapSloppy = new OverlapKernel(overlapPrecise, prec[0]); + } + + // create the mirror preconditioner overlap field + if (prec[1] == overlapPrecise->Precision()) { + overlapPrecondition = overlapPrecise; + } else if (prec[1] == overlapSloppy->Precision()) { + overlapPrecondition = overlapSloppy; + } else { + overlapPrecondition = new OverlapKernel(overlapPrecise, prec[1]); + } + + // create the mirror refinement overlap field + if (prec[2] == overlapSloppy->Precision()) { + overlapRefinement = overlapSloppy; + } else { + overlapRefinement = new OverlapKernel(overlapPrecise, prec[2]); + } + + // create the mirror eigensolver overlap field + if (prec[3] == overlapPrecise->Precision()) { + overlapEigensolver = overlapPrecise; + } else if (prec[3] == overlapSloppy->Precision()) { + overlapEigensolver = overlapSloppy; + } else if (prec[3] == overlapPrecondition->Precision()) { + overlapEigensolver = overlapPrecondition; + } else { + overlapEigensolver = new OverlapKernel(overlapPrecise, prec[3]); + } + } +} + +void loadOverlapQuda(QudaInvertParam *inv_param, QudaEigParam *eig_param) +{ + auto profile = pushProfile(profileOverlap); + pushVerbosity(inv_param->verbosity); + + checkInvertParam(inv_param); + checkEigParam(eig_param); + + if (gaugePrecise == nullptr) errorQuda("Gauge field must be loaded before clover"); + + ColorSpinorParam cpuParam(nullptr, *inv_param, 
gaugePrecise->X(), false, inv_param->input_location); + ColorSpinorParam cudaParam(cpuParam, *inv_param, QUDA_CUDA_FIELD_LOCATION); + cudaParam.setPrecision(inv_param->cuda_prec, inv_param->cuda_prec, true); + cudaParam.create = QUDA_ZERO_FIELD_CREATE; + cudaParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; + + const int n_eig = eig_param->n_conv; + std::vector evals(n_eig, 0.0); + std::vector evecs(n_eig, ColorSpinorField(cudaParam)); + + QudaEigParam eig_param_g5w = newQudaEigParam(); + eig_param_g5w.eig_type = QUDA_EIG_TR_LANCZOS; + eig_param_g5w.spectrum = QUDA_SPECTRUM_SR_EIG; + eig_param_g5w.use_dagger = QUDA_BOOLEAN_FALSE; + eig_param_g5w.use_norm_op = QUDA_BOOLEAN_TRUE; + eig_param_g5w.use_pc = QUDA_BOOLEAN_FALSE; + eig_param_g5w.compute_gamma5 = QUDA_BOOLEAN_FALSE; + eig_param_g5w.batched_rotate = 1; // Save device memory + eig_param_g5w.compute_evals_batch_size = 1; + + eig_param_g5w.use_poly_acc = eig_param->use_poly_acc; + eig_param_g5w.poly_deg = eig_param->poly_deg; + eig_param_g5w.a_min = eig_param->a_min * eig_param->a_min; + eig_param_g5w.a_max = (1 + 8 * inv_param->kappa) * (1 + 8 * inv_param->kappa); + eig_param_g5w.n_ev = eig_param->n_ev; + eig_param_g5w.n_kr = eig_param->n_kr; + eig_param_g5w.n_conv = eig_param->n_conv; + eig_param_g5w.tol = eig_param->tol; + eig_param_g5w.max_restarts = eig_param->max_restarts; + strcpy(eig_param_g5w.vec_infile, eig_param->vec_infile); + strcpy(eig_param_g5w.vec_outfile, eig_param->vec_outfile); + + DiracParam diracParam; + setDiracParam(diracParam, inv_param, false); + Dirac *d = new DiracWilson(diracParam); + + DiracMatrix *m = new DiracMdagM(*d); + auto *eig_solve = quda::EigenSolver::create(&eig_param_g5w, *m); + (*eig_solve)(evecs, evals); + delete eig_solve; + + // Recalculate eigenvalues + delete m; + m = new DiracG5M(*d); + ColorSpinorField tmp(cudaParam); + for (int i = 0; i < n_eig; ++i) { + (*m)(tmp, evecs[i]); + evals[i] = blas::cDotProduct(tmp, evecs[i]); + } + + delete m; + delete d; + + 
freeOverlapQuda(); + std::vector remez_tol(1, inv_param->overlap_invsqrt_tol); + overlapPrecise = new OverlapKernel(evecs, evals, inv_param->kappa, remez_tol); + QudaPrecision prec[] = {inv_param->cuda_prec_sloppy, inv_param->cuda_prec_precondition, + inv_param->cuda_prec_refinement_sloppy, inv_param->cuda_prec_eigensolver}; + loadSloppyOverlapQuda(prec); + + flushPoolQuda(QUDA_MEMORY_DEVICE); + + popVerbosity(); +} + // just free the sloppy fields used in mixed-precision solvers void freeSloppyGaugeQuda() { @@ -1515,6 +1675,7 @@ void endQuda(void) freeGaugeQuda(); freeCloverQuda(); + freeOverlapQuda(); flushChrono(); @@ -1560,6 +1721,7 @@ void endQuda(void) profileInit.Print(); profileGauge.Print(); profileClover.Print(); + profileOverlap.Print(); profileDslash.Print(); profileInvert.Print(); profileInvertMultiSrc.Print(); @@ -1658,6 +1820,7 @@ namespace quda { memcpy(diracParam.b_5, inv_param->b_5, sizeof(Complex) * inv_param->Ls); memcpy(diracParam.c_5, inv_param->c_5, sizeof(Complex) * inv_param->Ls); break; + case QUDA_OVERLAP_DSLASH: diracParam.type = QUDA_OVERLAP_DIRAC; break; case QUDA_STAGGERED_DSLASH: diracParam.type = pc ? 
QUDA_STAGGEREDPC_DIRAC : QUDA_STAGGERED_DIRAC; break; @@ -1702,6 +1865,7 @@ namespace quda { diracParam.fatGauge = gaugeFatPrecise; diracParam.longGauge = gaugeLongPrecise; diracParam.clover = cloverPrecise; + diracParam.overlap_kernel = overlapPrecise; diracParam.kappa = kappa; diracParam.mass = inv_param->mass; diracParam.m5 = inv_param->m5; @@ -1727,6 +1891,7 @@ namespace quda { diracParam.fatGauge = gaugeFatSloppy; diracParam.longGauge = gaugeLongSloppy; diracParam.clover = cloverSloppy; + diracParam.overlap_kernel = overlapSloppy; for (int i=0; i<4; i++) { diracParam.commDim[i] = 1; // comms are always on @@ -1745,6 +1910,7 @@ namespace quda { diracParam.fatGauge = gaugeFatRefinement; diracParam.longGauge = gaugeLongRefinement; diracParam.clover = cloverRefinement; + diracParam.overlap_kernel = overlapRefinement; for (int i=0; i<4; i++) { diracParam.commDim[i] = 1; // comms are always on @@ -1770,6 +1936,7 @@ namespace quda { diracParam.longGauge = gaugeLongPrecondition; } diracParam.clover = cloverPrecondition; + diracParam.overlap_kernel = overlapPrecondition; for (int i=0; i<4; i++) { diracParam.commDim[i] = comms ? 
1 : 0; @@ -1815,6 +1982,7 @@ namespace quda { diracParam.longGauge = gaugeLongEigensolver; } diracParam.clover = cloverEigensolver; + diracParam.overlap_kernel = overlapEigensolver; for (int i = 0; i < 4; i++) { diracParam.commDim[i] = 1; } @@ -1899,6 +2067,8 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity errorQuda("Gauge field not allocated"); if (cloverPrecise == nullptr && ((inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) || (inv_param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH))) errorQuda("Clover field not allocated"); + if (overlapPrecise == nullptr && (inv_param->dslash_type == QUDA_OVERLAP_DSLASH)) + errorQuda("Overlap kernel not allocated"); pushVerbosity(inv_param->verbosity); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(inv_param); @@ -2394,6 +2564,8 @@ void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) errorQuda("Gauge field not allocated"); if (cloverPrecise == nullptr && ((inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) || (inv_param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH))) errorQuda("Clover field not allocated"); + if (overlapPrecise == nullptr && (inv_param->dslash_type == QUDA_OVERLAP_DSLASH)) + errorQuda("Overlap kernel not allocated"); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(inv_param); bool pc = (inv_param->solution_type == QUDA_MATPC_SOLUTION || @@ -2458,6 +2630,8 @@ void MatDagMatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) errorQuda("Gauge field not allocated"); if (cloverPrecise == nullptr && ((inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) || (inv_param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH))) errorQuda("Clover field not allocated"); + if (overlapPrecise == nullptr && (inv_param->dslash_type == QUDA_OVERLAP_DSLASH)) + errorQuda("Overlap kernel not allocated"); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(inv_param); bool pc = (inv_param->solution_type == QUDA_MATPC_SOLUTION || @@ -2554,6 
+2728,34 @@ void checkClover(QudaInvertParam *param) { if (cloverEigensolver == nullptr) errorQuda("Eigensolver clover field doesn't exist"); } +void checkOverlap(QudaInvertParam *param) +{ + if (param->dslash_type != QUDA_OVERLAP_DSLASH) { return; } + + if (param->cuda_prec != overlapPrecise->Precision()) { + errorQuda("Solve precision %d doesn't match overlap precision %d", param->cuda_prec, overlapPrecise->Precision()); + } + if (param->kappa != overlapPrecise->Kappa()) { + errorQuda("Solve kappa %f doesn't match overlap kappa %f", param->kappa, overlapPrecise->Kappa()); + } + + if ((!overlapSloppy || param->cuda_prec_sloppy != overlapSloppy->Precision()) + || (!overlapPrecondition || param->cuda_prec_precondition != overlapPrecondition->Precision()) + || (!overlapRefinement || param->cuda_prec_refinement_sloppy != overlapRefinement->Precision()) + || (!overlapEigensolver || param->cuda_prec_eigensolver != overlapEigensolver->Precision())) { + freeSloppyOverlapQuda(); + QudaPrecision prec[4] = {param->cuda_prec_sloppy, param->cuda_prec_precondition, param->cuda_prec_refinement_sloppy, + param->cuda_prec_eigensolver}; + loadSloppyOverlapQuda(prec); + } + + if (overlapPrecise == nullptr) errorQuda("Precise overlap kernel doesn't exist"); + if (overlapSloppy == nullptr) errorQuda("Sloppy overlap kernel doesn't exist"); + if (overlapPrecondition == nullptr) errorQuda("Precondition kernel field doesn't exist"); + if (overlapRefinement == nullptr) errorQuda("Refinement kernel field doesn't exist"); + if (overlapEigensolver == nullptr) errorQuda("Eigensolver kernel field doesn't exist"); +} + quda::GaugeField *checkGauge(QudaInvertParam *param) { quda::GaugeField *U = param->dslash_type == QUDA_ASQTAD_DSLASH ? 
gaugeFatPrecise : @@ -2619,6 +2821,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param) } checkClover(param); + checkOverlap(param); return U; } @@ -2695,7 +2898,10 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // the correct QudaInvertParam values for the solve_type and // solution_type based on those three booleans - if (eig_param->use_pc) { + if (eig_param->chirality != QUDA_INVALID_CHIRALITY) { + inv_param->solve_type = QUDA_NORMOP_CHIRAL_SOLVE; + inv_param->solution_type = QUDA_MAT_SOLUTION; + } else if (eig_param->use_pc) { if (eig_param->use_norm_op) inv_param->solve_type = QUDA_NORMOP_PC_SOLVE; else @@ -2735,7 +2941,7 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam Dirac *dEig = nullptr; // Create the dirac operator with a sloppy and a precon. - bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE); + bool pc_solve = eig_param->use_pc; createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve, eig_param->use_smeared_gauge); Dirac &dirac = *dEig; //------------------------------------------------------ @@ -2743,7 +2949,8 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // Construct vectors //------------------------------------------------------ // Create host wrappers around application vector set - ColorSpinorParam cpuParam(nullptr, *inv_param, cudaGauge->X(), inv_param->solution_type, inv_param->input_location); + ColorSpinorParam cpuParam(nullptr, *inv_param, cudaGauge->X(), pc_solve, inv_param->input_location); + if (eig_param->chirality != QUDA_INVALID_CHIRALITY) { cpuParam.nSpin = 2; } int n_eig = eig_param->n_conv; if (eig_param->compute_svd) n_eig *= 2; @@ -2765,8 +2972,9 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam ColorSpinorParam cudaParam(cpuParam, *inv_param, QUDA_CUDA_FIELD_LOCATION); cudaParam.create = QUDA_ZERO_FIELD_CREATE; 
cudaParam.setPrecision(inv_param->cuda_prec_eigensolver, inv_param->cuda_prec_eigensolver, true); + cudaParam.mem_type = eig_param->mem_type_ritz; // Ensure device vectors qre in UKQCD basis for Wilson type fermions - if (cudaParam.nSpin != 1) cudaParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; + if (cudaParam.nSpin == 4) cudaParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; std::vector kSpace(n_eig); for (int i = 0; i < n_eig; i++) { @@ -2805,7 +3013,10 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // multiply by gamma5. Each combination requires a unique Dirac operator // object. DiracMatrix *m = nullptr; - if (!eig_param->use_norm_op && !eig_param->use_dagger && eig_param->compute_gamma5) { + if (eig_param->chirality != QUDA_INVALID_CHIRALITY) { + m = new DiracMdagMChiral(dirac); + ((DiracMdagMChiral *)m)->setChirality(eig_param->chirality); + } else if (!eig_param->use_norm_op && !eig_param->use_dagger && eig_param->compute_gamma5) { m = new DiracG5M(dirac); } else if (!eig_param->use_norm_op && !eig_param->use_dagger && !eig_param->compute_gamma5) { m = new DiracM(dirac); @@ -3097,7 +3308,7 @@ deflated_solver::deflated_solver(QudaEigParam &eig_param, TimeProfile &profile) if (ritzParam.location==QUDA_CUDA_FIELD_LOCATION) { ritzParam.setPrecision(param->cuda_prec_ritz, param->cuda_prec_ritz, true); // set native field order - if (ritzParam.nSpin != 1) ritzParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; + if (ritzParam.nSpin == 4) ritzParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; //select memory location here, by default ritz vectors will be allocated on the device //but if not sufficient device memory, then the user may choose mapped type of memory @@ -3629,6 +3840,40 @@ void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, Quda callMultiSrcQuda(_hp_x, _hp_b, param, op, parity); } +namespace quda +{ + void separateChiral(std::vector &b_left, std::vector &b_right, + const ColorSpinorField &b, double nb) + { + 
ColorSpinorParam chiralParam(b); + chiralParam.nSpin = 2; + chiralParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; + chiralParam.setPrecision(chiralParam.Precision(), chiralParam.Precision(), true); + b_left.resize(0); + b_right.resize(0); + { + ColorSpinorField tmp_left(chiralParam), tmp_right(chiralParam); + spinorChiralProject(tmp_left, tmp_right, b); + if (blas::norm2(tmp_left) / nb > 1e-16) { b_left.push_back(std::move(tmp_left)); } + if (blas::norm2(tmp_right) / nb > 1e-16) { b_right.push_back(std::move(tmp_right)); } + } + } + + void combineChiral(cvector_ref &x_left, cvector_ref &x_right, + cvector_ref &x) + { + auto tmp = getFieldTmp(x[0]); + for (size_t i = 0; i < x_left.size(); i++) { + spinorChiralReconstruct(tmp, x_left[i], QUDA_LEFT_CHIRALITY); + blas::xpy(tmp, x[i]); + } + for (size_t i = 0; i < x_right.size(); i++) { + spinorChiralReconstruct(tmp, x_right[i], QUDA_RIGHT_CHIRALITY); + blas::xpy(tmp, x[i]); + } + } +} // namespace quda + /*! * Generic version of the multi-shift solver. Should work for * most fermions. Note that offset[0] is not folded into the mass parameter. 
@@ -3663,6 +3908,7 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) bool pc_solve = (param->solve_type == QUDA_DIRECT_PC_SOLVE) || (param->solve_type == QUDA_NORMOP_PC_SOLVE); bool mat_solution = (param->solution_type == QUDA_MAT_SOLUTION) || (param->solution_type == QUDA_MATPC_SOLUTION); bool direct_solve = (param->solve_type == QUDA_DIRECT_SOLVE) || (param->solve_type == QUDA_DIRECT_PC_SOLVE); + bool chiral_solve = (param->solve_type == QUDA_NORMERR_CHIRAL_SOLVE); if (param->dslash_type == QUDA_ASQTAD_DSLASH || param->dslash_type == QUDA_STAGGERED_DSLASH) { @@ -3675,6 +3921,19 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) errorQuda("For Staggered-type fermions, multi-shift solver only supports DIRECT_PC solve types"); } + } else if (param->dslash_type == QUDA_OVERLAP_DSLASH) { + + if (!chiral_solve) { + errorQuda("For Overlap fermions, multi-shift solver only support NORMERR_CHIRAL solve types"); + } + if (direct_solve) { + errorQuda("For Overlap fermions, multi-shift solver does not support DIRECT or DIRECT_PC solve types"); + } + if (pc_solution || pc_solve) { + errorQuda( + "For Overlap fermions, multi-shift solver does not support preconditioned (PC) solution_type or solve_type"); + } + } else { // Wilson type if (mat_solution) { @@ -3716,6 +3975,11 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) param->mass = sqrt(param->offset[0]/4); } + // We solve m / (1 - m) + D in multi-shift solver + // But we actually use m + (1 - m) D as DiracOverlap::M() + // so mass = 0 here to get a D without any shift + if (param->dslash_type == QUDA_OVERLAP_DSLASH) { param->mass = 0.0; } + Dirac *d = nullptr; Dirac *dSloppy = nullptr; Dirac *dPre = nullptr; @@ -3800,13 +4064,40 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) param->dslash_type == QUDA_STAGGERED_DSLASH) { m = new DiracM(dirac); mSloppy = new DiracM(diracSloppy); + } else if (chiral_solve) { + m 
= new DiracMdagMChiral(dirac); + mSloppy = new DiracMdagMChiral(diracSloppy); } else { m = new DiracMdagM(dirac); mSloppy = new DiracMdagM(diracSloppy); } + std::vector b_left, b_right; + if (chiral_solve) { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + cudaParam.nSpin = 2; + cudaParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; + cudaParam.setPrecision(cudaParam.Precision(), cudaParam.Precision(), true); + separateChiral(b_left, b_right, b, nb); + blas::zero(x); + } + std::vector x_left(b_left.size() * param->num_offset, cudaParam); + std::vector x_right(b_right.size() * param->num_offset, cudaParam); + SolverParam solverParam(*param); - { + if (chiral_solve) { + // high-mode propagator for chiral overlap fermion + for (QudaChirality chirality : {QUDA_LEFT_CHIRALITY, QUDA_RIGHT_CHIRALITY}) { + auto &b_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? b_left : b_right; + auto &x_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? x_left : x_right; + ((DiracMdagMChiral *)m)->setChirality(chirality); + ((DiracMdagMChiral *)mSloppy)->setChirality(chirality); + if (b_chiral.size() > 0) { + MultiShiftCG cg_m(*m, *mSloppy, solverParam); + cg_m(x_chiral, b_chiral[0], p, r2_old); + } + } + } else { MultiShiftCG cg_m(*m, *mSloppy, solverParam); cg_m(x, b, p, r2_old); } @@ -3864,6 +4155,9 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) param->dslash_type == QUDA_STAGGERED_DSLASH) { m = new DiracM(dirac); mSloppy = new DiracM(diracSloppy); + } else if (chiral_solve) { + m = new DiracMdagMChiral(dirac); + mSloppy = new DiracMdagMChiral(diracSloppy); } else { m = new DiracMdagM(dirac); mSloppy = new DiracMdagM(diracSloppy); @@ -3908,7 +4202,21 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) solverParam.tol_hq = param->tol_hq_offset[i]; // set heavy quark tolerance solverParam.delta = param->reliable_delta_refinement; - { + if (chiral_solve) { + for (QudaChirality chirality : {QUDA_LEFT_CHIRALITY, QUDA_RIGHT_CHIRALITY}) { + 
auto &b_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? b_left : b_right; + auto &x_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? x_left : x_right; + ((DiracMdagMChiral *)m)->setChirality(chirality); + ((DiracMdagMChiral *)mSloppy)->setChirality(chirality); + if (b_chiral.size() > 0) { + CG cg(*m, *mSloppy, *mSloppy, *mSloppy, solverParam); + if (i == 0) + cg(x_chiral[i], b_chiral[0], p[i], r2_old[i]); + else + cg(x_chiral[i], b_chiral[0]); + } + } + } else { CG cg(*m, *mSloppy, *mSloppy, *mSloppy, solverParam); if (i == 0) cg(x[i], b, p[i], r2_old[i]); @@ -3919,7 +4227,6 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) solverParam.true_res_offset[i] = static_cast(solverParam.true_res); solverParam.true_res_hq_offset[i] = static_cast(solverParam.true_res_hq); solverParam.updateInvertParam(*param,i); - if (param->dslash_type == QUDA_ASQTAD_DSLASH || param->dslash_type == QUDA_STAGGERED_DSLASH) { dirac.setMass(sqrt(param->offset[0]/4)); // restore just in case @@ -3932,6 +4239,24 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) } } + if (chiral_solve) { combineChiral(x_left, x_right, x); } + + // We have to reconstruct the solution for overlap fermions + if (param->dslash_type == QUDA_OVERLAP_DSLASH) { + auto tmp = getFieldTmp(x[0]); + for (int i = 0; i < param->num_offset; i++) { + double mass = sqrt(param->offset[i] / (param->offset[i] + 1.0)); + // (m^2 / (1 - m^2) + D)^{-1} ==> (m^2 + (1 - m^2) D)^{-1} + blas::ax(1 / (1 - mass * mass), x[i]); + d->setMass(mass); + if (mat_solution) { + blas::copy(tmp, x[i]); + d->Mdag(x[i], tmp); + } + d->reconstruct(x[i], b, param->solution_type); + } + } + // restore shifts for (int i = 0; i < param->num_offset; i++) param->offset[i] = unscaled_shifts[i]; @@ -5270,7 +5595,7 @@ void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_par inv_param->dslash_type = QUDA_ASQTAD_DSLASH; - ColorSpinorParam cpuParam(h_in, *inv_param, X, QUDA_MAT_SOLUTION, 
QUDA_CPU_FIELD_LOCATION); + ColorSpinorParam cpuParam(h_in, *inv_param, X, false, QUDA_CPU_FIELD_LOCATION); cpuParam.nSpin = 1; // QUDA style pointer for host data. ColorSpinorField in_h(cpuParam); @@ -5295,6 +5620,7 @@ void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_par diracParam.fatGauge = gaugeFatPrecise; diracParam.longGauge = gaugeLongPrecise; diracParam.clover = cloverPrecise; + diracParam.overlap_kernel = overlapPrecise; diracParam.kappa = inv_param->kappa; diracParam.mass = inv_param->mass; diracParam.m5 = inv_param->m5; diff --git a/lib/inv_multi_cg_quda.cpp b/lib/inv_multi_cg_quda.cpp index dba0e11f4e..31b2fbf766 100644 --- a/lib/inv_multi_cg_quda.cpp +++ b/lib/inv_multi_cg_quda.cpp @@ -71,7 +71,7 @@ namespace quda { public: ShiftUpdate(ColorSpinorField &r, std::vector &p, std::vector &x, std::vector &alpha, std::vector &beta, std::vector &zeta, - std::vector &zeta_old, int j_low, int n_shift) : + std::vector &zeta_old, int j_low, int n_shift, int n_update) : r(r), p(p), x(x), @@ -81,7 +81,7 @@ namespace quda { zeta_old(zeta_old), j_low(j_low), n_shift(n_shift), - n_update((r.Nspin() == 4) ? 
4 : 2) + n_update(n_update) { } @@ -265,7 +265,8 @@ namespace quda { // now create the worker class for updating the shifted solutions and gradient vectors bool aux_update = false; - ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now); + ShiftUpdate shift_update(r_sloppy, p, x_sloppy, alpha, beta, zeta, zeta_old, j_low, num_offset_now, + matSloppy.getStencilSteps()); getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); getProfile().TPSTART(QUDA_PROFILE_COMPUTE); @@ -284,7 +285,7 @@ namespace quda { shift_update.updateNshift(num_offset_now); // at some point we should curry these into the Dirac operator - if (r.Nspin() == 4) + if (r.Nspin() != 1) pAp = blas::axpyReDot(offset[0], p[0], Ap); else pAp = blas::reDotProduct(p[0], Ap); @@ -343,7 +344,7 @@ namespace quda { } mat(r, x[0]); - if (r.Nspin() == 4) blas::axpy(offset[0], x[0], r); + if (r.Nspin() != 1) blas::axpy(offset[0], x[0], r); r2[0] = blas::xmyNorm(b, r); for (int j = 1; j < num_offset_now; j++) r2[j] = zeta[j] * zeta[j] * r2[0]; @@ -451,7 +452,7 @@ namespace quda { // 2.) For shift 0 if we did not exit early (we went to the full solution) if ( (i > 0 and not mixed) or (i == 0 and not exit_early) ) { mat(r, x[i]); - if (r.Nspin() == 4) { + if (r.Nspin() != 1) { blas::axpy(offset[i], x[i], r); // Offset it. } else if (i != 0) { blas::axpy(offset[i] - offset[0], x[i], r); // Offset it. diff --git a/lib/multi_blas_quda.cu b/lib/multi_blas_quda.cu index 0828dc3653..ce0449eb8c 100644 --- a/lib/multi_blas_quda.cu +++ b/lib/multi_blas_quda.cu @@ -99,7 +99,8 @@ namespace quda { staticCheck(f, x, y); constexpr bool site_unroll_check = !std::is_same::value || isFixed::value; - if (site_unroll_check && (x[0].Ncolor() != 3 || x[0].Nspin() == 2)) + // TODO: Is x[0].Nspin() == 2 check needed here? 
+ if (site_unroll_check && (x[0].Ncolor() != 3 && x[0].Nspin() == 2)) errorQuda("site unroll not supported for nSpin = %d nColor = %d", x[0].Nspin(), x[0].Ncolor()); TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); @@ -115,7 +116,8 @@ namespace quda { constexpr bool site_unroll = !std::is_same::value || isFixed::value; constexpr int N = n_vector(nSpin, site_unroll); constexpr int Ny = n_vector(nSpin, site_unroll); - constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread + // TODO: Shall we use n_vector(nSpin, true) here? + constexpr int M = site_unroll ? (nSpin * 6) : N; // real numbers per thread const int length = x[0].Length() / (nParity * M); if (tp.aux.x > 1 && (length * tp.aux.x) % device::warp_size() != 0) { diff --git a/lib/multi_reduce_quda.cu b/lib/multi_reduce_quda.cu index e718209729..69d08135e5 100644 --- a/lib/multi_reduce_quda.cu +++ b/lib/multi_reduce_quda.cu @@ -98,7 +98,8 @@ namespace quda { auto &x0 = x[0]; constexpr bool site_unroll_check = !std::is_same::value || isFixed::value; - if (site_unroll_check && (x0.Ncolor() != 3 || x0.Nspin() == 2)) + // TODO: Is x0.Nspin() == 2 check needed here? + if (site_unroll_check && (x0.Ncolor() != 3 && x0.Nspin() == 2)) errorQuda("site unroll not supported for nSpin = %d nColor = %d", x0.Nspin(), x0.Ncolor()); TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); @@ -114,7 +115,8 @@ namespace quda { constexpr bool site_unroll = !std::is_same::value || isFixed::value; constexpr int N = n_vector(nSpin, site_unroll); constexpr int Ny = n_vector(nSpin, site_unroll); - constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread + // TODO: Shall we use n_vector(nSpin, true) here? + constexpr int M = site_unroll ? 
(nSpin * 6) : N; // real numbers per thread const int length = x0.Length() / M; MultiReduceArg arg(x, y, z, w, r_, NYW, length, nParity); @@ -651,7 +653,9 @@ namespace quda { max_YW_size>(x.size(), x0.Precision(), y0.Precision()); // if fine-grid then we set max tile size to 32 to avoid unnecessary tuning - uint2 max_tile_size = make_uint2(1, std::min({NYW_max, (int)y.size(), x0.Ncolor() == 3 ? 32 : NYW_max})); + // if not on device memory we set max tile size to 8 to avoid extreamly long tuning + const int x0_max_tile_size = x0.Ncolor() == 3 ? (x0.MemType() == QUDA_MEMORY_DEVICE ? 32 : 8) : NYW_max; + uint2 max_tile_size = make_uint2(1, std::min({NYW_max, (int)y.size(), x0_max_tile_size})); multiReduce_recurse(result_tmp, x, y, x, x, 0, 0, false, max_tile_size); } else if (y.size() == 1 && x0.Precision() == y0.Precision()) { @@ -663,7 +667,9 @@ namespace quda { max_YW_size>(y.size(), y0.Precision(), x0.Precision()); // if fine-grid then we set max tile size to 32 to avoid unnecessary tuning - uint2 max_tile_size = make_uint2(1, std::min({NXZ_max, (int)x.size(), x0.Ncolor() == 3 ? 32 : NXZ_max})); + // if not on device memory we set max tile size to 8 to avoid extreamly long tuning + const int x0_max_tile_size = x0.Ncolor() == 3 ? (x0.MemType() == QUDA_MEMORY_DEVICE ? 
32 : 8) : NXZ_max; + uint2 max_tile_size = make_uint2(1, std::min({NXZ_max, (int)x.size(), x0_max_tile_size})); multiReduce_recurse(result_trans, y, x, y, y, 0, 0, false, max_tile_size); // transpose the result if we are doing the transpose calculation diff --git a/lib/overlap_kernel.cpp b/lib/overlap_kernel.cpp new file mode 100644 index 0000000000..5fe85dc250 --- /dev/null +++ b/lib/overlap_kernel.cpp @@ -0,0 +1,180 @@ +#include +#include +#include + +namespace quda +{ + // Chebyshev polynomial the first kind + // T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x) + double Tn(double x, int n) + { + if (abs(x) <= 1.0) { return cos(n * std::acos(x)); } + double T0 = 1, T1 = x, Tk = 2 * x * x - 1; + switch (n) { + case 0: return T0; + case 1: return T1; + case 2: return Tk; + default: + for (int k = 3; k <= n; ++k) { + T0 = T1; + T1 = Tk; + Tk = 2 * x * T1 - T0; + } + return Tk; + } + } + + // \sum_{i=0}^n c_i T_i + // T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x) + // Use Clenshaw algorithm + double ciTi(double x, std::vector c, int n) + { + double b2 = 0.0, b1 = 0.0, bk; + for (int k = n; k >= 1; --k) { + bk = c[k] + 2 * x * b1 - b2; + b2 = b1; + b1 = bk; + } + return c[0] + x * b1 - b2; + } + + // (\sum_{i=0}^n c_i T_i)' = \sum_{i=1}^n i c_i U_{i-1} + // U_{k+1}(x) = 2 x U_k(x) - U_{k-1} + // Use Clenshaw algorithm + double iciUim1(double x, std::vector c, int n) + { + double b2 = 0.0, b1 = 0.0, bk; + for (int k = n - 1; k >= 1; --k) { + bk = (k + 1) * c[k + 1] + 2 * x * b1 - b2; + b2 = b1; + b1 = bk; + } + return c[1] + 2 * x * b1 - b2; + } + + double residual(double x, std::vector c, int n, double epsilon, bool derivative) + { + const double z = (x * 2 - (1 + epsilon)) / (1 - epsilon); + if (derivative) { + return -1 / (2 * sqrt(x)) * ciTi(z, c, n) - sqrt(x) * iciUim1(z, c, n) * (2 / (1 - epsilon)); + } else { + return 1 - sqrt(x) * ciTi(z, c, n); + } + } + + double findRoot(double x_l, double x_r, std::vector c, int n, double epsilon, bool derivative) + { + double x_m, res_r, res_l, 
res_m; + + res_l = residual(x_l, c, n, epsilon, derivative); + res_r = residual(x_r, c, n, epsilon, derivative); + if (abs(res_l) < 1e-15) return x_l; + if (abs(res_r) < 1e-15) return x_r; + if (res_r * res_l > 0) + errorQuda("ERROR: findRoot with derivative=%d called with wrong ends: (%e %e)->(%e %e)\n", derivative, x_l, x_r, + res_l, res_r); + for (int i = 0; i < 10; i++) { + x_m = (res_l * x_r - res_r * x_l) / (res_l - res_r); + res_m = residual(x_m, c, n, epsilon, derivative); + if (res_m * res_l > 0) { + x_l = x_m; + res_l = res_m; + } else { + x_r = x_m; + res_r = res_m; + } + } + return (res_l * x_r - res_r * x_l) / (res_l - res_r); + } + + std::vector minimaxApproximationRemez(double delta, double epsilon) + { + const int n_ref = ceil(-log(delta / 0.41) / (2.083 * sqrt(epsilon))) + 1; + bool converged = false; + constexpr int max_iter = 5; + std::vector y, z, c, b; + for (int n = n_ref; n < n_ref * 1.1; n++) { + y.resize(n + 1); + z.resize(n + 1); + c.resize(n + 1); + b.resize(n + 1); + Eigen::Map b_eigen(b.data(), b.size()), c_eigen(c.data(), c.size()); + Eigen::MatrixXd M_eigen(n + 1, n + 1); + + for (int i = 0; i < n + 1; ++i) { + z[i] = cos(M_PI * i / n); + y[i] = (z[i] * (1 - epsilon) + (1 + epsilon)) / 2; + } + + int iter = 0; + while (iter < max_iter) { + // Construct matrix M_ij=\sqrt{y_i}T_j(z_i) + for (int i = 0; i < n + 1; ++i) { + for (int j = 0; j < n; ++j) { M_eigen(i, j) = sqrt(y[i]) * Tn(z[i], j); } + M_eigen(i, n) = i % 2 == 0 ? 
1 : -1; // last column carries the alternating-sign equioscillation error term of the Remez system, not T_n + b_eigen(i) = 1.0; + } + c_eigen = M_eigen.lu().solve(b_eigen); + + // Drop T_n + for (int i = 0; i < n; ++i) { b[i] = findRoot(y[i], y[i + 1], c, n - 1, epsilon, false); } + for (int i = n - 1; i > 0; --i) { y[i] = findRoot(b[i], b[i - 1], c, n - 1, epsilon, true); } + for (int i = 1; i < n; ++i) { z[i] = (2 * y[i] - (1 + epsilon)) / (1 - epsilon); } + for (int i = 0; i < n + 1; ++i) { b[i] = abs(1 - sqrt(y[i]) * ciTi(z[i], c, n - 1)); } + if (*std::max_element(b.begin(), b.end()) <= delta) { break; } + iter += 1; + } + if (iter < max_iter) { + converged = true; + break; + } + } + if (!converged) errorQuda("Remez algorithm did not converge\n"); + return {c.begin(), c.end() - 1}; + } + + OverlapKernel::OverlapKernel(std::vector &evecs, const std::vector &evals, double kappa, + const std::vector remez_tol) : + evals(evals.size()), + kappa(kappa), + epsilon(pow(evals.back().real() / (1.0 + 8.0 * kappa), 2)), + remez_tol(remez_tol), + remez_coeff(remez_tol.size()), + remez_order(remez_tol.size()) + { + this->evecs = std::move(evecs); + for (size_t i = 0; i < evals.size(); i++) { this->evals[i] = evals[i].real(); } + for (size_t i = 0; i < remez_tol.size(); i++) { + remez_coeff[i] = minimaxApproximationRemez(remez_tol[i], epsilon); + remez_order[i] = remez_coeff[i].size() - 1; + } + } + + OverlapKernel::OverlapKernel(const OverlapKernel *overlap_kernel, QudaPrecision precision) : + evals(overlap_kernel->evals), + kappa(overlap_kernel->kappa), + epsilon(overlap_kernel->epsilon), + remez_tol(overlap_kernel->remez_tol), + remez_coeff(overlap_kernel->remez_tol.size()), + remez_order(overlap_kernel->remez_tol.size()) + { + ColorSpinorParam param(overlap_kernel->evecs[0]); + param.setPrecision(precision, precision, true); + evecs.resize(overlap_kernel->evecs.size(), ColorSpinorField(param)); + for (size_t i = 0; i < overlap_kernel->evecs.size(); i++) { evecs[i].copy(overlap_kernel->evecs[i]); } + double prec_tol = 0.0; + 
switch (precision) { + case QUDA_DOUBLE_PRECISION: prec_tol = std::numeric_limits::epsilon() / 2.; break; + case QUDA_SINGLE_PRECISION: prec_tol = std::numeric_limits::epsilon() / 2.; break; + case QUDA_HALF_PRECISION: prec_tol = pow(2., -16); break; + case QUDA_QUARTER_PRECISION: prec_tol = pow(2., -8); break; + default: errorQuda("Invalid precision %d", precision); break; + } + for (size_t i = 0; i < remez_tol.size(); i++) { + double tol = std::max(remez_tol[i], prec_tol); + remez_tol[i] = tol; + remez_coeff[i] = minimaxApproximationRemez(tol, epsilon); + remez_order[i] = remez_coeff[i].size() - 1; + } + } +} // namespace quda diff --git a/lib/quda_ptr.cpp b/lib/quda_ptr.cpp index ac1a9bdd8b..a657c2a342 100644 --- a/lib/quda_ptr.cpp +++ b/lib/quda_ptr.cpp @@ -10,7 +10,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_INIT); if (pool && (type != QUDA_MEMORY_DEVICE && type != QUDA_MEMORY_HOST_PINNED && type != QUDA_MEMORY_HOST)) - errorQuda("Memory pool not available for memory type %d", type); + warningQuda("Memory pool not available for memory type %d", type); if (size > 0) { switch (type) { @@ -46,6 +46,10 @@ namespace quda device = nullptr; host = ptr; break; + case QUDA_MEMORY_MAPPED: // TODO: Is this needed here? + host = ptr; + device = (ptr != nullptr) ? get_mapped_device_pointer(ptr) : nullptr; + break; case QUDA_MEMORY_MANAGED: device = ptr; host = ptr; @@ -78,6 +82,7 @@ namespace quda case QUDA_MEMORY_HOST: host_free(host); break; case QUDA_MEMORY_HOST_PINNED: pool ? pool_pinned_free(host) : host_free(host); break; case QUDA_MEMORY_MAPPED: host_free(host); break; + case QUDA_MEMORY_MANAGED: managed_free(host); break; // TODO: Is this needed here? 
default: errorQuda("Unknown memory type %d", type); } getProfile().TPSTOP(QUDA_PROFILE_FREE); diff --git a/lib/reduce_quda.cu b/lib/reduce_quda.cu index ea9ce04437..3972bfbace 100644 --- a/lib/reduce_quda.cu +++ b/lib/reduce_quda.cu @@ -64,7 +64,8 @@ namespace quda { void apply(const qudaStream_t &stream) override { constexpr bool site_unroll_check = !std::is_same::value || isFixed::value || decltype(r)::site_unroll; - if (site_unroll_check && (x.Ncolor() != 3 || x.Nspin() == 2)) + // Nspin == 2 site unroll is now supported via n_vector(nSpin, ...), but site unroll still requires Ncolor == 3 + if (site_unroll_check && x.Ncolor() != 3) errorQuda("site unroll not supported for nSpin = %d nColor = %d", x.Nspin(), x.Ncolor()); TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); @@ -79,7 +80,8 @@ namespace quda { constexpr bool site_unroll = !std::is_same::value || isFixed::value || decltype(r)::site_unroll; constexpr int N = n_vector(nSpin, site_unroll); constexpr int Ny = n_vector(nSpin, site_unroll); - constexpr int M = site_unroll ? (nSpin == 4 ? 24 : 6) : N; // real numbers per thread + // TODO: Shall we use n_vector(nSpin, true) here? + constexpr int M = site_unroll ? 
(nSpin * 6) : N; // real numbers per thread const int length = x.Length() / M; ReductionArg arg(x, y, z, w, v, r_, length, nParity); diff --git a/lib/solve.cpp b/lib/solve.cpp index 79fa868633..c4c45601c1 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -33,6 +33,20 @@ namespace quda for (auto &b2i : b2) printfQuda("Mass rescale: norm of source in = %g\n", b2i); } + // overlap dslash uses mass normalization internally + if (param.dslash_type == QUDA_OVERLAP_DSLASH) { + switch (param.solution_type) { + case QUDA_MAT_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(param.mass, b); + break; + case QUDA_MATDAG_MAT_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(param.mass * param.mass, b); + break; + default: errorQuda("Not implemented"); + } + return; + } + // staggered dslash uses mass normalization internally if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { switch (param.solution_type) { @@ -125,6 +139,46 @@ namespace quda } } + void separateChiral(std::vector &idx_left, std::vector &in_left, + std::vector &idx_right, std::vector &in_right, + cvector_ref &in, std::vector &nb) + { + ColorSpinorParam chiralParam(in[0]); + chiralParam.nSpin = 2; + chiralParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; + chiralParam.setPrecision(chiralParam.Precision(), chiralParam.Precision(), true); + in_left.resize(0); + in_right.resize(0); + for (size_t i = 0; i < in.size(); i++) { + ColorSpinorField tmp_left(chiralParam); + ColorSpinorField tmp_right(chiralParam); + spinorChiralProject(tmp_left, tmp_right, in[i]); + if (blas::norm2(tmp_left) / nb[i] > 1e-6) { + idx_left.push_back(i); + in_left.push_back(std::move(tmp_left)); + } + if (blas::norm2(tmp_right) / nb[i] > 1e-6) { + idx_right.push_back(i); + in_right.push_back(std::move(tmp_right)); + } + } + } + + void combineChiral(std::vector &idx_left, cvector_ref &out_left, + std::vector &idx_right, cvector_ref &out_right, + 
cvector_ref &out) + { + auto tmp = getFieldTmp(out[0]); + for (size_t i = 0; i < out_left.size(); i++) { + spinorChiralReconstruct(tmp, out_left[i], QUDA_LEFT_CHIRALITY); + blas::xpy(tmp, out[idx_left[i]]); + } + for (size_t i = 0; i < out_right.size(); i++) { + spinorChiralReconstruct(tmp, out_right[i], QUDA_RIGHT_CHIRALITY); + blas::xpy(tmp, out[idx_right[i]]); + } + } + void solve(cvector_ref &x, cvector_ref &b, Dirac &dirac, Dirac &diracSloppy, Dirac &diracPre, Dirac &diracEig, QudaInvertParam &param) { @@ -132,7 +186,9 @@ namespace quda bool mat_solution = (param.solution_type == QUDA_MAT_SOLUTION) || (param.solution_type == QUDA_MATPC_SOLUTION); bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || (param.solve_type == QUDA_DIRECT_PC_SOLVE); - bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || (param.solve_type == QUDA_NORMERR_PC_SOLVE); + bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || (param.solve_type == QUDA_NORMERR_PC_SOLVE) + || (param.solve_type == QUDA_NORMERR_CHIRAL_SOLVE); + bool chiral_solve = (param.solve_type == QUDA_NORMERR_CHIRAL_SOLVE); auto nb = blas::norm2(b); for (auto &bi : nb) { @@ -187,6 +243,8 @@ namespace quda // MAT NORMOP Solve (A^dag A) x = (A^dag b) // MATDAG_MAT NORMOP Solve (A^dag A) x = b // MAT NORMERR Solve (A A^dag) y = b, then x = A^dag y + // MAT CHIRAL Solve (A A^dag) y = b on both chiralities, then x = A^dag y + // MATDAG_MAT CHIRAL Solve (A A^dag) x = b on both chiralities // // We generally require that the solution_type and solve_type // preconditioning match. 
As an exception, the unpreconditioned MAT @@ -212,7 +270,37 @@ namespace quda solverParam.updateInvertParam(param); } - if (direct_solve) { + if (chiral_solve) { // (A A^dag) y = b or (A A^dag) x = b on both chiralities + DiracMdagMChiral m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + + std::vector idx_left, idx_right; + std::vector in_left, in_right; + separateChiral(idx_left, in_left, idx_right, in_right, in, nb); + auto out_left = getFieldTmp(in_left); + auto out_right = getFieldTmp(in_right); + + for (QudaChirality chirality : {QUDA_LEFT_CHIRALITY, QUDA_RIGHT_CHIRALITY}) { + auto &in_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? in_left : in_right; + auto &out_chiral = (chirality == QUDA_LEFT_CHIRALITY) ? out_left : out_right; + m.setChirality(chirality); + mSloppy.setChirality(chirality); + mPre.setChirality(chirality); + mEig.setChirality(chirality); + if (in_chiral.size() > 0) { + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out_chiral, in_chiral); + delete solve; + solverParam.updateInvertParam(param); + } + } + combineChiral(idx_left, out_left, idx_right, out_right, out); + if (mat_solution) { // then x = A^dag y + auto tmp = getFieldTmp(out); + blas::copy(tmp, out); + dirac.Mdag(out, tmp); + } + } else if (direct_solve) { // A x = b, or A x = y where A^dag y = b DiracM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); SolverParam solverParam(param); @@ -227,7 +315,7 @@ namespace quda (*solve)(out, in); delete solve; solverParam.updateInvertParam(param); - } else if (!norm_error_solve) { + } else if (!norm_error_solve) { // (A^dag A) x = b, or (A^dag A) x = b' where b' = A^dag b DiracMdagM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); SolverParam solverParam(param); @@ -251,7 +339,7 @@ namespace quda delete solve; solverParam.updateInvertParam(param); } - } else { // norm_error_solve + } else { // (A A^dag) y = b, then x = A^dag y 
DiracMMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); auto tmp = getFieldTmp(cvector_ref(in)); SolverParam solverParam(param); diff --git a/lib/spinor_chiral_project.cu b/lib/spinor_chiral_project.cu new file mode 100644 index 0000000000..c83d781180 --- /dev/null +++ b/lib/spinor_chiral_project.cu @@ -0,0 +1,167 @@ +#include +#include +#include +#include + +namespace quda +{ + + template class SpinorChiralReconstruct : TunableKernel2D + { + ColorSpinorField &out; + const ColorSpinorField &in_left; + const ColorSpinorField &in_right; + const QudaChirality chirality; + unsigned int minThreads() const { return out.VolumeCB(); } + + public: + SpinorChiralReconstruct(ColorSpinorField &out, const ColorSpinorField &in_left, const ColorSpinorField &in_right, + QudaChirality chirality) : + TunableKernel2D(out, out.SiteSubset()), out(out), in_left(in_left), in_right(in_right), chirality(chirality) + { + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + if (chirality == QUDA_INVALID_CHIRALITY) { + ChiralReconstructSpinorArg arg(out, in_left, in_right); + launch(tp, stream, arg); + } else if (chirality == QUDA_LEFT_CHIRALITY) { + ChiralReconstructSpinorArg arg(out, in_left, in_right); + launch(tp, stream, arg); + } else if (chirality == QUDA_RIGHT_CHIRALITY) { + ChiralReconstructSpinorArg arg(out, in_left, in_right); + launch(tp, stream, arg); + } else { + errorQuda("Unsupported chirality %d", chirality); + } + } + + long long bytes() const + { + return ((chirality != QUDA_RIGHT_CHIRALITY) ? in_left.Bytes() : 0) + + ((chirality != QUDA_LEFT_CHIRALITY) ? 
in_right.Bytes() : 0) + out.Bytes(); + } + }; + + void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src_left, + const ColorSpinorField &src_right, QudaChirality chirality) + { + checkPrecision(dst, src_left, src_right); + checkColor(dst, src_left, src_right); + + if (dst.Nspin() != 4 || src_left.Nspin() != 2 || src_right.Nspin() != 2) { + errorQuda("Unsupported nspin combination: dst=%d, src_left=%d, src_right=%d\n", dst.Nspin(), src_left.Nspin(), + src_right.Nspin()); + } + if (dst.GammaBasis() != QUDA_UKQCD_GAMMA_BASIS || src_left.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS + || src_right.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS) { + errorQuda("Unsupported gamma basis combination: dst %d, src_left %d, src_right %d\n", dst.GammaBasis(), + src_left.GammaBasis(), src_right.GammaBasis()); + } + + if (dst.Ncolor() == 3) { + if (dst.Precision() == QUDA_DOUBLE_PRECISION) { + SpinorChiralReconstruct(dst, src_left, src_right, chirality); + } else if (dst.Precision() == QUDA_SINGLE_PRECISION) { + SpinorChiralReconstruct(dst, src_left, src_right, chirality); + } else { + errorQuda("Precision %d not implemented", dst.Precision()); + } + } else { + errorQuda("nColor=%d not implemented", dst.Ncolor()); + } + } + + void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality) + { + spinorChiralReconstruct(dst, src, src, chirality); + } + + void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src_left, const ColorSpinorField &src_right) + { + spinorChiralReconstruct(dst, src_left, src_right, QUDA_INVALID_CHIRALITY); + } + + template class SpinorChiralProject : TunableKernel2D + { + ColorSpinorField &out_left; + ColorSpinorField &out_right; + const ColorSpinorField &in; + const QudaChirality chirality; + unsigned int minThreads() const { return in.VolumeCB(); } + + public: + SpinorChiralProject(ColorSpinorField &out_left, ColorSpinorField &out_right, const ColorSpinorField &in, + 
QudaChirality chirality) : + TunableKernel2D(in, in.SiteSubset()), out_left(out_left), out_right(out_right), in(in), chirality(chirality) + { + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + if (chirality == QUDA_INVALID_CHIRALITY) { + ChiralProjectSpinorArg arg(out_left, out_right, in); + launch(tp, stream, arg); + } else if (chirality == QUDA_LEFT_CHIRALITY) { + ChiralProjectSpinorArg arg(out_left, out_right, in); + launch(tp, stream, arg); + } else if (chirality == QUDA_RIGHT_CHIRALITY) { + ChiralProjectSpinorArg arg(out_left, out_right, in); + launch(tp, stream, arg); + } else { + errorQuda("Unsupported chirality %d", chirality); + } + } + + long long bytes() const + { + return in.Bytes() + ((chirality != QUDA_RIGHT_CHIRALITY) ? out_left.Bytes() : 0) + + ((chirality != QUDA_LEFT_CHIRALITY) ? out_right.Bytes() : 0); + } + }; + + void spinorChiralProject(ColorSpinorField &dst_left, ColorSpinorField &dst_right, const ColorSpinorField &src, + QudaChirality chirality) + { + checkPrecision(dst_left, dst_right, src); + checkColor(dst_left, dst_right, src); + + if (dst_left.Nspin() != 2 || dst_right.Nspin() != 2 || src.Nspin() != 4) { + errorQuda("Unsupported nspin combination: dst_left=%d, dst_right=%d, src=%d\n", dst_left.Nspin(), + dst_right.Nspin(), src.Nspin()); + } + if (dst_left.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS + || dst_right.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || src.GammaBasis() != QUDA_UKQCD_GAMMA_BASIS) { + errorQuda("Unsupported gamma basis combination: dst_left %d, dst_right %d, src %d\n", dst_left.GammaBasis(), + dst_right.GammaBasis(), src.GammaBasis()); + } + + if (src.Ncolor() == 3) { + if (src.Precision() == QUDA_DOUBLE_PRECISION) { + SpinorChiralProject(dst_left, dst_right, src, chirality); + } else if (src.Precision() == QUDA_SINGLE_PRECISION) { + SpinorChiralProject(dst_left, dst_right, src, chirality); + } else { + 
errorQuda("Precision %d not implemented", src.Precision()); + } + } else { + errorQuda("nColor=%d not implemented", src.Ncolor()); + } + } + + void spinorChiralProject(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality) + { + spinorChiralProject(dst, dst, src, chirality); + } + + void spinorChiralProject(ColorSpinorField &dst_left, ColorSpinorField &dst_right, const ColorSpinorField &src) + { + spinorChiralProject(dst_left, dst_right, src, QUDA_INVALID_CHIRALITY); + } + +} // namespace quda