diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 60521dba72..746eaefcb2 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -155,10 +155,10 @@ namespace quda QudaFieldLocation location = QUDA_CPU_FIELD_LOCATION) : LatticeFieldParam(4, X, 0, location, inv_param.cpu_prec), nColor(3), - nSpin((inv_param.dslash_type == QUDA_ASQTAD_DSLASH || inv_param.dslash_type == QUDA_STAGGERED_DSLASH - || inv_param.dslash_type == QUDA_LAPLACE_DSLASH) ? - 1 : - 4), + nSpin((inv_param.dslash_type == QUDA_LAPLACE_DSLASH) ? inv_param.laplace_nspin : + (inv_param.dslash_type == QUDA_COVDEV_DSLASH) ? inv_param.covdev_nspin : + (inv_param.dslash_type == QUDA_ASQTAD_DSLASH || inv_param.dslash_type == QUDA_STAGGERED_DSLASH) ? 1 : + 4), twistFlavor(inv_param.twist_flavor), gammaBasis(nSpin == 4 ? inv_param.gamma_basis : QUDA_DEGRAND_ROSSI_GAMMA_BASIS), create(QUDA_REFERENCE_FIELD_CREATE), diff --git a/include/dirac_quda.h b/include/dirac_quda.h index 7372c956e3..44f289577a 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -44,6 +44,7 @@ namespace quda { GaugeField *longGauge; // used by staggered only int laplace3D; int covdev_mu; + bool covdev_shift; CloverField *clover; GaugeField *xInvKD; // used for the Kahler-Dirac operator only @@ -111,6 +112,7 @@ namespace quda { printfQuda("mass = %g\n", mass); printfQuda("laplace3D = %d\n", laplace3D); printfQuda("covdev_mu = %d\n", covdev_mu); + printfQuda("covdev_shift = %d\n", covdev_shift); printfQuda("m5 = %g\n", m5); printfQuda("Ls = %d\n", Ls); printfQuda("matpcType = %d\n", matpcType); @@ -2249,6 +2251,7 @@ namespace quda { protected: int covdev_mu; + int covdev_shift; public: GaugeCovDev(const DiracParam ¶m); @@ -2262,6 +2265,11 @@ namespace quda { virtual void MCD(cvector_ref &out, cvector_ref &in, const int mu) const; virtual void MdagMCD(cvector_ref &out, cvector_ref &in, const int mu) const; + virtual void DslashS(cvector_ref &out, cvector_ref &in, QudaParity parity, 
+ int mu) const; + virtual void MS(cvector_ref &out, cvector_ref &in, const int mu) const; + virtual void MdagMS(cvector_ref &out, cvector_ref &in, const int mu) const; + virtual void Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, diff --git a/include/dslash_quda.h b/include/dslash_quda.h index f34a41de1a..b75e999724 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -793,11 +793,12 @@ namespace quda @param[in] mu Direction of the derivative. For mu > 3 it goes backwards @param[in] parity Destination parity @param[in] dagger Whether this is for the dagger operator + @param[in] shift Whether to apply the shift instead of the covariant derivative @param[in] comm_override Override for which dimensions are partitioned @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyCovDev(cvector_ref &out, cvector_ref &in, const GaugeField &U, - int mu, int parity, bool dagger, const int *comm_override, TimeProfile &profile); + int mu, int parity, bool dagger, bool shift, const int *comm_override, TimeProfile &profile); /** @brief Apply clover-matrix field to a color-spinor field diff --git a/include/gauge_tools.h b/include/gauge_tools.h index e600311a99..9676e20109 100644 --- a/include/gauge_tools.h +++ b/include/gauge_tools.h @@ -178,6 +178,34 @@ namespace quda void GFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, double epsilon, QudaGaugeSmearType smear_type, QudaWFlowStepType step_type); + /** + * @brief Rotate gauge field U_\mu(x) with rotation field g(x). + * U'_\mu(x) = g(x)U_\mu(x)g^\dagger(x+\mu) + * @param[in,out] out Rotated gauge field U'_\mu(x) + * @param[in] in Gauge field U_\mu(x) + * @param[in] rot Rotation field g(x) + */ + void gaugeRotate(GaugeField &out, const GaugeField &in, const GaugeField &rot); + + /** + * @brief Gauge fixing with over-relaxation. 
+ * @param[in,out] rot Rotation field to fix the gauge + * @param[in] u Gauge field + * @param[in] omega The over-relaxation parameter, most common value is 1.5 or 1.7 + * @param[in] dir_ignore The ignored direction, 3 (Coulomb gauge) and 4 (Landau gauge) are common choices + */ + void gaugeFixOVRStep(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore); + + /** + * @brief Compute the gauge fixing quality, functional or theta is considered as the criterion. + * @param[in,out] quality The functional and theta value + * @param[in] u Fixed gauge field + * @param[in] rot Rotation field + * @param[in] dir_ignore The ignored direction, 3 (Coulomb gauge) and 4 (Landau gauge) are common choices + * @param[in] compute_theta Set to true to compute the theta value as the criterion + */ + void gaugeFixQuality(double quality[2], const GaugeField &rot, const GaugeField &u, int dir_ignore, bool compute_theta); + /** * @brief Gauge fixing with overrelaxation with support for single and multi GPU. 
* @param[in,out] data, quda gauge field diff --git a/include/kernels/covariant_derivative.cuh b/include/kernels/covariant_derivative.cuh index 1a7f9d5706..b041a55c13 100644 --- a/include/kernels/covariant_derivative.cuh +++ b/include/kernels/covariant_derivative.cuh @@ -14,7 +14,7 @@ namespace quda /** @brief Parameter structure for driving the covariant derivative operator */ - template + template struct CovDevArg : DslashArg { static constexpr int nColor = nColor_; static constexpr int nSpin = nSpin_; @@ -29,6 +29,8 @@ namespace quda static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD; typedef typename gauge_mapper::type G; + static constexpr bool shift = shift_; + typedef typename mapper::type real; F out[MAX_MULTI_RHS]; /** output vector field */ @@ -82,25 +84,33 @@ namespace quda const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); const bool ghost = (coord[d] + 1 >= arg.dc.X[d]) && isActive(active, thread_dim, d, coord, arg); - const Link U = arg.U(d, coord.x_cb, parity); - if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndex<1>(coord, arg.dc.X, d, arg.nFace); const Vector in = arg.halo.Ghost(d, 1, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += U * in; + if constexpr (Arg::shift) { + out += in; + } else { + const Link U = arg.U(d, coord.x_cb, parity); + out += U * in; + } } else if (doBulk() && !ghost) { const Vector in = arg.in[src_idx](fwd_idx, their_spinor_parity); - out += U * in; + + if constexpr (Arg::shift) { + out += in; + } else { + const Link U = arg.U(d, coord.x_cb, parity); + out += U * in; + } } } else if (mu >= 4 && arg.dd_in.doHopping(coord, d, -1)) { // Backward gather - compute back offset for spinor and gauge fetch const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc); - const int gauge_idx = back_idx; const bool ghost = (coord[d] - 1 < 0) && isActive(active, thread_dim, d, coord, arg); @@ -110,13 +120,23 @@ namespace quda const Link U = arg.U.Ghost(d, ghost_idx, 1 - 
parity); const Vector in = arg.halo.Ghost(d, 0, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity); - out += conj(U) * in; + if constexpr (Arg::shift) { + out += in; + } else { + const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity); + out += conj(U) * in; + } } else if (doBulk() && !ghost) { - const Link U = arg.U(d, gauge_idx, 1 - parity); const Vector in = arg.in[src_idx](back_idx, their_spinor_parity); - out += conj(U) * in; + if constexpr (Arg::shift) { + out += in; + } else { + const int gauge_idx = back_idx; + const Link U = Arg::shift ? Link() : arg.U(d, gauge_idx, 1 - parity); + out += conj(U) * in; + } } } // Forward/backward derivative } diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 8b66ee83e6..955112bba9 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -40,7 +40,7 @@ namespace quda const G U; /** the gauge field */ const real a; /** xpay scale factor - can be -kappa or -kappa^2 */ /** parameters for distance preconditioning */ - const real alpha0; + const double alpha0; const int t0; WilsonArg(cvector_ref &out, cvector_ref &in, const ColorSpinorField &halo, @@ -88,10 +88,10 @@ namespace quda const int t = coord.gx[3]; const int nt = arg.globalDim3; - real fwd_coeff_3 - = Arg::distance_pc ? distanceWeight(arg, t + 1, nt) / distanceWeight(arg, t, nt) : static_cast(1.0); - real bwd_coeff_3 - = Arg::distance_pc ? distanceWeight(arg, t - 1, nt) / distanceWeight(arg, t, nt) : static_cast(1.0); + real fwd_coeff_3 = Arg::distance_pc ? static_cast(distanceWeight(arg, t + 1, nt) / distanceWeight(arg, t, nt)) : + static_cast(1.0); + real bwd_coeff_3 = Arg::distance_pc ? 
static_cast(distanceWeight(arg, t - 1, nt) / distanceWeight(arg, t, nt)) : + static_cast(1.0); #pragma unroll for (int d = 0; d < 4; d++) { // loop over dimension - 4 and not nDim since this is used for DWF as well diff --git a/include/kernels/gauge_fix_ovr2.cuh b/include/kernels/gauge_fix_ovr2.cuh new file mode 100644 index 0000000000..c846804823 --- /dev/null +++ b/include/kernels/gauge_fix_ovr2.cuh @@ -0,0 +1,186 @@ +#include +#include +#include +#include +#include + +namespace quda +{ + + template + struct FixGaugeArg : kernel_param<> { + using Float = Float_; + static constexpr int nColor = nColor_; + static_assert(nColor == 3, "Only nColor=3 enabled at this time"); + static constexpr QudaReconstructType recon = recon_; + static constexpr int parity = parity_; + static constexpr bool over_relaxation = over_relaxation_; + typedef typename gauge_mapper::type Gauge; + + Gauge rot; + const Gauge u; + + int X[4]; // grid dimensions + int border[4]; + const Float omega; + const int dir_ignore; + const Float tolerance; + + FixGaugeArg(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore) : + kernel_param(dim3(u.LocalVolumeCB())), + rot(rot), + u(u), + omega(omega), + dir_ignore(dir_ignore), + tolerance(u.toleranceSU3()) + { + for (int dir = 0; dir < 4; ++dir) { + border[dir] = u.R()[dir]; + X[dir] = u.X()[dir] - border[dir] * 2; + } + } + }; + + /** + * @brief Maximize the real trace of UW in SU(2) subgroups and apply + * the result to U. Note it's equivalent to a closest unitary matrix + * problem in SU(2) subgroups. 
+ * U' = \argmax_{U}\mathfrak{Re}\mathrm{Tr}(UV) = \argmin_{U}||U-V^\dagger||_F^2 + * Specifically, for GL(2,C) matrices, we have + * U = \frac{V^\dagger}{\sqrt{\det(V^\dagger)}} + * Also accepts the over-relaxation parameter for gauge fixing + */ + template + __host__ __device__ inline void argmaxReTrUW(Matrix, 3> &U, Matrix, 3> &W, const Arg &arg) + { + int i1, i2; + switch (su2_index) { + case 0: i1 = 0, i2 = 1; break; + case 1: i1 = 1, i2 = 2; break; + case 2: i1 = 0, i2 = 2; break; + default: break; + } + + Matrix, 3> V = U * W; + double versors[4]; // use double to avoid precision issues + + versors[0] = static_cast(V(i1, i1).real() + V(i2, i2).real()); + versors[1] = static_cast(V(i1, i1).imag() - V(i2, i2).imag()); + versors[2] = static_cast(V(i1, i2).real() - V(i2, i1).real()); + versors[3] = static_cast(V(i1, i2).imag() + V(i2, i1).imag()); + + double norm + = sqrt(versors[0] * versors[0] + versors[1] * versors[1] + versors[2] * versors[2] + versors[3] * versors[3]); + if (norm > arg.tolerance) { + double inv_norm = 1.0 / norm; + versors[0] *= inv_norm; +#pragma unroll + for (int i = 1; i < 4; ++i) { versors[i] *= -inv_norm; } + } else { + versors[0] = 1.0; +#pragma unroll + for (int i = 1; i < 4; ++i) { versors[i] = 0.0; } + } + + if constexpr (Arg::over_relaxation) { + // a workaround for fp32 numerical issues + // clamp versors[0] to [-1, 1] to avoid NaN in acos + // if (versors[0] > 1.0) { + // versors[0] = 1.0; + // } else if (versors[0] < -1.0) { + // versors[0] = -1.0; + // } + double angle = acos(versors[0]); + double sin_angle = sin(angle); + double sin_omega_angle, cos_omega_angle; + sincos(arg.omega * angle, &sin_omega_angle, &cos_omega_angle); + versors[0] = cos_omega_angle; + if (sin_angle > arg.tolerance) { + double coeff = sin_omega_angle / sin_angle; +#pragma unroll + for (int i = 1; i < 4; ++i) { versors[i] *= coeff; } + } else { +#pragma unroll + for (int i = 1; i < 4; ++i) { versors[i] = 0.0; } + } + } + + setIdentity(&V); + V(i1, i1) = 
complex(static_cast(versors[0]), static_cast(versors[1])); + V(i2, i2) = complex(static_cast(versors[0]), static_cast(-versors[1])); + V(i1, i2) = complex(static_cast(versors[2]), static_cast(versors[3])); + V(i2, i1) = complex(static_cast(-versors[2]), static_cast(versors[3])); + + U = V * U; + } + + // /** + // * @brief Solving the closest unitary matrix in SU(2) subgroups. + // * This is another way to project the input matrix on the SU(3) group. + // */ + // template + // __host__ __device__ inline void closestSu3(Matrix,3> &in, Float tol) + // { + // Matrix, 3> out; + // setIdentity(&out); + + // constexpr int max_iter = 100; + // int i = 0; + // Float old_retr, retr = getTrace(in).real(); + // do { // iterate until matrix is unitary + // // loop over SU(2) subgroup indices + // argmaxReTrUW<0, Float>(out, in); + // argmaxReTrUW<1, Float>(out, in); + // argmaxReTrUW<2, Float>(out, in); + // old_retr = retr; + // retr = getTrace(out * in).real(); + // } while (abs(retr - old_retr) / old_retr > tol && ++i < max_iter); + + // in = out; + // } + + template struct FixGauge { + const Arg &arg; + constexpr FixGauge(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ inline void operator()(int x_cb) + { + using real = typename Arg::Float; + typedef Matrix, Arg::nColor> Link; + constexpr int parity = Arg::parity; + + // compute spacetime and local coords + int x[4], X[4]; +#pragma unroll + for (int dr = 0; dr < 4; ++dr) { X[dr] = arg.X[dr]; } + getCoords(x, x_cb, X, parity); +#pragma unroll + for (int dr = 0; dr < 4; ++dr) { + x[dr] += arg.border[dr]; + X[dr] += 2 * arg.border[dr]; + } + + Link g, K, tmp, tmp2; + g = arg.rot(0, linkIndex(x, X), parity); +#pragma unroll + for (int dir = 0; dir < 4; ++dir) { + if (dir != arg.dir_ignore) { + tmp = arg.u(dir, linkIndex(x, X), parity); + tmp2 = arg.rot(0, linkIndexP1(x, X, dir), 1 - parity); + K += tmp * conj(tmp2); + tmp = arg.u(dir, linkIndexM1(x, X, 
dir), 1 - parity); + tmp2 = arg.rot(0, linkIndexM1(x, X, dir), 1 - parity); + K += conj(tmp2 * tmp); + } + } + + // loop over SU(2) subgroup indices + argmaxReTrUW<0, real>(g, K, arg); + argmaxReTrUW<1, real>(g, K, arg); + argmaxReTrUW<2, real>(g, K, arg); + + arg.rot(0, linkIndex(x, X), parity) = g; + } + }; +} // namespace quda diff --git a/include/kernels/gauge_fix_quality.cuh b/include/kernels/gauge_fix_quality.cuh new file mode 100644 index 0000000000..e603b891a3 --- /dev/null +++ b/include/kernels/gauge_fix_quality.cuh @@ -0,0 +1,93 @@ +#include +#include +#include +#include +#include +#include + +namespace quda +{ + + template + struct GaugeFixQualityArg : public ReduceArg> { + using Float = Float_; + static constexpr int nColor = nColor_; + static_assert(nColor == 3, "Only nColor=3 enabled at this time"); + static constexpr QudaReconstructType recon = recon_; + static constexpr bool compute_theta = compute_theta_; + typedef typename gauge_mapper::type Gauge; + + const Gauge u; + const Gauge rot; + + int X[4]; // grid dimensions + int border[4]; + const int dir_ignore; + + GaugeFixQualityArg(const GaugeField &u, const GaugeField &rot, int dir_ignore) : + ReduceArg(dim3(u.LocalVolumeCB(), 2)), u(u), rot(rot), dir_ignore(dir_ignore) + { + for (int dir = 0; dir < 4; ++dir) { + border[dir] = u.R()[dir]; + X[dir] = u.X()[dir] - border[dir] * 2; + } + } + }; + + template struct GaugeFixQuality : plus { + using reduce_t = typename Arg::reduce_t; + using plus::operator(); + static constexpr int reduce_block_dim = 2; // x_cb in x, parity in y + const Arg &arg; + constexpr GaugeFixQuality(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity) + { + reduce_t quality {0, 0}; + + using real = typename Arg::Float; + typedef Matrix, Arg::nColor> Link; + + // compute spacetime and local coords + int x[4], X[4]; +#pragma unroll + for (int dr = 
0; dr < 4; ++dr) X[dr] = arg.X[dr]; + getCoords(x, x_cb, X, parity); +#pragma unroll + for (int dr = 0; dr < 4; ++dr) { + x[dr] += arg.border[dr]; + X[dr] += 2 * arg.border[dr]; + } + + Link g0, g, U, V; + g0 = arg.rot(0, linkIndex(x, X), parity); +#pragma unroll + for (int dir = 0; dir < 4; ++dir) { + if (dir != arg.dir_ignore) { + g = arg.rot(0, linkIndexP1(x, X, dir), 1 - parity); + U = arg.u(dir, linkIndex(x, X), parity); + V += U * conj(g); + } + } + quality[0] = getTrace(g0 * V).real(); + + if constexpr (Arg::compute_theta) { +#pragma unroll + for (int dir = 0; dir < 4; ++dir) { + if (dir != arg.dir_ignore) { + g = arg.rot(0, linkIndexM1(x, X, dir), 1 - parity); + U = arg.u(dir, linkIndexM1(x, X, dir), 1 - parity); + V += conj(g * U); + } + } + V = g0 * V; + V -= conj(V); + SubTraceUnit(V); + quality[1] = getRealTraceUVdagger(V, V); + } + + return operator()(quality, value); + } + }; +} // namespace quda diff --git a/include/kernels/gauge_rotate.cuh b/include/kernels/gauge_rotate.cuh new file mode 100644 index 0000000000..90263f2e49 --- /dev/null +++ b/include/kernels/gauge_rotate.cuh @@ -0,0 +1,65 @@ +#include +#include +#include +#include +#include +#include + +namespace quda +{ + template struct RotateGaugeArg : kernel_param<> { + using Float = Float_; + static constexpr int nColor = nColor_; + static_assert(nColor == 3, "Only nColor=3 enabled at this time"); + static constexpr QudaReconstructType recon = recon_; + typedef typename gauge_mapper::type Gauge; + + Gauge out; + const Gauge in; + const Gauge rot; + + int X[4]; // grid dimensions + int border[4]; + + RotateGaugeArg(GaugeField &out, const GaugeField &in, const GaugeField &rot) : + kernel_param(dim3(in.LocalVolumeCB(), 2, 4)), out(out), in(in), rot(rot) + { + for (int dir = 0; dir < 4; ++dir) { + border[dir] = rot.R()[dir]; + X[dir] = rot.X()[dir] - border[dir] * 2; + } + } + }; + + template struct RotateGauge { + const Arg &arg; + constexpr RotateGauge(const Arg &arg) : arg(arg) { } + static 
constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ inline void operator()(int x_cb, int parity, int dir) + { + using real = typename Arg::Float; + typedef Matrix, Arg::nColor> Link; + + // compute spacetime and local coords + int x[4], X[4]; +#pragma unroll + for (int dr = 0; dr < 4; ++dr) { X[dr] = arg.X[dr]; } + getCoords(x, x_cb, X, parity); +#pragma unroll + for (int dr = 0; dr < 4; ++dr) { + x[dr] += arg.border[dr]; + X[dr] += 2 * arg.border[dr]; + } + + Link g, U; + U = arg.in(dir, x_cb, parity); + g = arg.rot(0, linkIndex(x, X), parity); + U = g * U; + g = arg.rot(0, linkIndexP1(x, X, dir), 1 - parity); + U = U * conj(g); + + arg.out(dir, x_cb, parity) = U; + } + }; +} // namespace quda diff --git a/include/kernels/pgauge_init.cuh b/include/kernels/pgauge_init.cuh index 583c567d6f..b9456f889d 100644 --- a/include/kernels/pgauge_init.cuh +++ b/include/kernels/pgauge_init.cuh @@ -13,10 +13,11 @@ namespace quda { - template + template struct InitGaugeColdArg : kernel_param<> { static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; + static constexpr QudaFieldGeometry geom = geom_; using real = typename mapper::type; using Gauge = typename gauge_mapper::type; int X[4]; // grid dimensions @@ -38,14 +39,15 @@ namespace quda { { Matrix, Arg::nColor> U; setIdentity(&U); - for ( int d = 0; d < 4; d++ ) arg.U(d, x_cb, parity) = U; + for (int d = 0; d < Arg::geom; d++) arg.U(d, x_cb, parity) = U; } }; - template + template struct InitGaugeHotArg : kernel_param<> { static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; + static constexpr QudaFieldGeometry geom = geom_; using real = typename mapper::type; using Gauge = typename gauge_mapper::type; int X[4]; // grid dimensions @@ -213,7 +215,7 @@ namespace quda { getCoords(x, x_cb, arg.X, parity); for (int dr = 0; dr < 4; dr++) x[dr] += arg.border[dr]; int e_cb = linkIndex(x, X); - for (int d = 0; d < 4; d++) { + for 
(int d = 0; d < Arg::geom; d++) { Matrix, Arg::nColor> U; U = randomize(localState); arg.U(d, e_cb, parity) = U; diff --git a/include/kernels/spinor_reweight.cuh b/include/kernels/spinor_reweight.cuh index 3315e3b64d..e1ef91ab08 100644 --- a/include/kernels/spinor_reweight.cuh +++ b/include/kernels/spinor_reweight.cuh @@ -16,9 +16,9 @@ namespace quda int X[4]; V v; - real alpha0; + double alpha0; int t0; - SpinorDistanceReweightArg(ColorSpinorField &v, real alpha0, int t0) : + SpinorDistanceReweightArg(ColorSpinorField &v, double alpha0, int t0) : kernel_param(dim3(v.VolumeCB(), v.SiteSubset(), 1)), v(v), alpha0(alpha0), t0(t0) { for (int dir = 0; dir < 4; dir++) X[dir] = v.X()[dir]; @@ -26,13 +26,12 @@ namespace quda } }; - template __device__ __host__ inline auto distanceWeight(const Arg &arg, int t, int nt) + template __device__ __host__ inline double distanceWeight(const Arg &arg, int t, int nt) { - using real = typename Arg::real; if (arg.alpha0 > 0) { - return cosh(arg.alpha0 * real((t - arg.t0 + nt) % nt - nt / 2)); + return cosh(arg.alpha0 * double((t - arg.t0 + nt) % nt - nt / 2)); } else { - return 1 / cosh(arg.alpha0 * real((t - arg.t0 + nt) % nt - nt / 2)); + return 1 / cosh(arg.alpha0 * double((t - arg.t0 + nt) % nt - nt / 2)); } } @@ -47,7 +46,8 @@ namespace quda int x[4]; getCoords(x, x_cb, arg.X, parity); Vector tmp = arg.v(x_cb, parity); - tmp *= distanceWeight(arg, arg.comms_coord[3] * arg.X[3] + x[3], arg.comms_dim[3] * arg.X[3]); + tmp *= static_cast( + distanceWeight(arg, arg.comms_coord[3] * arg.X[3] + x[3], arg.comms_dim[3] * arg.X[3])); arg.v(x_cb, parity) = tmp; } }; diff --git a/include/kernels/unitarize_links.cuh b/include/kernels/unitarize_links.cuh index 5c6745b520..dfbde398e0 100644 --- a/include/kernels/unitarize_links.cuh +++ b/include/kernels/unitarize_links.cuh @@ -30,10 +30,9 @@ namespace quda { const double svd_abs_error; const static bool check_unitarization = true; - UnitarizeArg(GaugeField &out, const GaugeField &in, int* 
fails, int max_iter, - double unitarize_eps, double max_error, int reunit_allow_svd, - int reunit_svd_only, double svd_rel_error, double svd_abs_error) : - kernel_param(dim3(in.VolumeCB(), 2, 4)), + UnitarizeArg(GaugeField &out, const GaugeField &in, int *fails, int max_iter, double unitarize_eps, double max_error, + int reunit_allow_svd, int reunit_svd_only, double svd_rel_error, double svd_abs_error) : + kernel_param(dim3(in.VolumeCB(), 2, in.Geometry())), out(out), in(in), fails(fails), diff --git a/include/quda.h b/include/quda.h index 0bd715c5f4..830f3dbb24 100644 --- a/include/quda.h +++ b/include/quda.h @@ -62,6 +62,8 @@ extern "C" { QudaReconstructType reconstruct_eigensolver; /**< The recontruction type of the eigensolver gauge field */ QudaGaugeFixed gauge_fix; /**< Whether the input gauge field is in the axial gauge or not */ + QudaBoolean gauge_fix_compute_theta; /**< Compute theta in the gauge fixing algorithm */ + QudaBoolean gauge_fix_use_theta; /**< Use theta as the criterion in the gauge fixing algorithm */ int ga_pad; /**< The pad size that native GaugeFields will use (default=0) */ @@ -139,8 +141,11 @@ extern "C" { QudaTwistFlavorType twist_flavor; /**< Twisted mass flavor */ - int laplace3D; /**< omit this direction from laplace operator: x,y,z,t -> 0,1,2,3 (-1 is full 4D) */ - int covdev_mu; /**< Apply forward/backward covariant derivative in direction mu(mu<=3)/mu-4(mu>3) */ + int laplace3D; /**< omit this direction from laplace operator: x,y,z,t -> 0,1,2,3 (-1 is full 4D) */ + int laplace_nspin; /**< Number of spin for the Laplace operator */ + int covdev_mu; /**< Apply forward/backward covariant derivative in direction mu(mu<=3)/mu-4(mu>3) */ + bool covdev_shift; /**< Apply the shift instead of the covariant derivative */ + int covdev_nspin; /**< Number of spin for the covariant derivative operator */ double tol; /**< Solver tolerance in the L2 residual norm */ double tol_restart; /**< Solver tolerance in the L2 residual norm (used to 
restart InitCG) */ @@ -885,6 +890,19 @@ extern "C" { A negative value means 3D for APE/STOUT and 4D for OVRIMP_STOUT/HYP */ } QudaGaugeSmearParam; + typedef struct QudaGaugeFixParam_s { + size_t struct_size; /**< Size of this struct in bytes. Used to ensure that the host application and QUDA see the same struct */ + double tol; /**< Tolerance of the fixing */ + int maxiter; /**< The maximum number of fixing steps to perform. */ + int dir_ignore; /**< The direction to be ignored by the fixing algorithm */ + double omega; /**< Parameter used for OVR algorithm */ + double alpha; /**< Parameter used for FFT algorithm */ + int reunit_interval; /**< Interval at which to reunitarize the rotation field */ + int verbose_interval; /**< Interval at which to print the gauge fixing progress */ + bool compute_theta; /**< Compute theta as the quality metric */ + bool use_theta; /**< Use theta instead of relative difference of functional as the criterion */ + } QudaGaugeFixParam; + typedef struct QudaBLASParam_s { size_t struct_size; /**< Size of this struct in bytes. Used to ensure that the host application and QUDA see the same struct*/ @@ -1094,6 +1112,15 @@ extern "C" { */ QudaGaugeSmearParam newQudaGaugeSmearParam(void); + /** + * A new QudaGaugeFixParam should always be initialized + * immediately after it's defined (and prior to explicitly setting + * its members) using this function. Typical usage is as follows: + * + * QudaGaugeFixParam fix_param = newQudaGaugeFixParam(); + */ + QudaGaugeFixParam newQudaGaugeFixParam(void); + /** * A new QudaBLASParam should always be initialized immediately * after it's defined (and prior to explicitly setting its members) @@ -1133,6 +1160,18 @@ extern "C" { */ void printQudaGaugeObservableParam(QudaGaugeObservableParam *param); + /** + * Print the members of QudaGaugeSmearParam. + * @param param The QudaGaugeSmearParam whose elements we are to print. 
+ */ + void printQudaGaugeSmearParam(QudaGaugeSmearParam *param); + + /** + * Print the members of QudaGaugeFixParam. + * @param param The QudaGaugeFixParam whose elements we are to print. + */ + void printQudaGaugeFixParam(QudaGaugeFixParam *param); + /** * Print the members of QudaBLASParam. * @param param The QudaBLASParam whose elements we are to print. @@ -1779,6 +1818,25 @@ extern "C" { const int src_colors, const int *X, const int *const source_position, const int n_mom, const int *const mom_modes, const QudaFFTSymmType *const fft_type); + /** + * @brief Rotate gauge field U_\mu(x) with the rotation field g(x) + * U'_\mu(x) = g(x) U_\mu(x) g^\dagger(x+\mu) + * @param[in] rotation Rotation field g(x) to rotate the gauge + * @param[in,out] gauge Gauge field U_\mu(x) to be rotated + * @param[in] param Parameters of the external fields + */ + void performGaugeRotateQuda(void *rotation, void *gauge, QudaGaugeParam *param); + + /** + * @brief Gauge fixing with over-relaxation. + * U'_\mu(x) = g(x) U_\mu(x) g^\dagger(x+\mu) + * @param[in,out] rotation Rotation field g(x) to fix the gauge + * @param[in,out] gauge Gauge field U_\mu(x) to be fixed + * @param[in] param Parameters of the external fields + * @param[in] fix_param Parameters of the gauge fixing algorithm + */ + void performGaugeFixQuda(void *rotation, void *gauge, QudaGaugeParam *param, QudaGaugeFixParam *fix_param); + /** * @brief Gauge fixing with overrelaxation with support for single and multi GPU. 
* @param[in,out] gauge, gauge field to be fixed diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 6043cc0250..f4a4186e3a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -106,7 +106,8 @@ set (QUDA_OBJS device_vector.cu inv_gmresdr_quda.cpp pgauge_exchange.cu pgauge_init.cu pgauge_heatbath.cu random.cu - gauge_fix_fft.cu gauge_fix_ovr.cu pgauge_det_trace.cu clover_outer_product.cu + gauge_fix_fft.cu gauge_fix_ovr.cu gauge_fix_ovr2.cu gauge_fix_quality.cu + gauge_rotate.cu pgauge_det_trace.cu clover_outer_product.cu clover_sigma_outer_product.cu momentum.cu gauge_qcharge.cu deflation.cpp checksum.cu transform_reduce.cu dslash5_mobius_eofa.cu diff --git a/lib/check_params.h b/lib/check_params.h index 6dbdc91dfd..5048a46b8a 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -106,6 +106,14 @@ void printQudaGaugeParam(QudaGaugeParam *param) { P(gauge_fix, QUDA_GAUGE_FIXED_INVALID); P(ga_pad, INVALID_INT); +#ifdef INIT_PARAM + P(gauge_fix_compute_theta, QUDA_BOOLEAN_TRUE); + P(gauge_fix_use_theta, QUDA_BOOLEAN_FALSE); +#else + P(gauge_fix_compute_theta, QUDA_BOOLEAN_INVALID); + P(gauge_fix_use_theta, QUDA_BOOLEAN_INVALID); +#endif + #if defined INIT_PARAM P(staggered_phase_type, QUDA_STAGGERED_PHASE_NO); P(staggered_phase_applied, 0); @@ -375,7 +383,10 @@ void printQudaInvertParam(QudaInvertParam *param) { P(tm_rho, 0.0); P(twist_flavor, QUDA_TWIST_INVALID); P(laplace3D, INVALID_INT); + P(laplace_nspin, 1); P(covdev_mu, INVALID_INT); + P(covdev_shift, false); + P(covdev_nspin, 4); #else // asqtad and domain wall use mass parameterization if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == QUDA_ASQTAD_DSLASH @@ -1207,6 +1218,50 @@ void printQudaGaugeSmearParam(QudaGaugeSmearParam *param) #endif } +#if defined INIT_PARAM +QudaGaugeFixParam newQudaGaugeFixParam(void) +{ + QudaGaugeFixParam ret; +#elif defined CHECK_PARAM +static void checkGaugeFixParam(QudaGaugeFixParam *param) +{ +#else +void 
printQudaGaugeFixParam(QudaGaugeFixParam *param) +{ + printfQuda("QUDA Gauge Fix Parameters:\n"); +#endif + +#if defined CHECK_PARAM + if (param->struct_size != (size_t)INVALID_INT && param->struct_size != sizeof(*param)) + errorQuda("Unexpected QudaGaugeFixParam struct size %lu, expected %lu", param->struct_size, sizeof(*param)); +#else + P(struct_size, (size_t)INVALID_INT); +#endif + + // P(fix_type, QUDA_GAUGE_FIX_INVALID); + P(tol, INVALID_DOUBLE); + P(maxiter, INVALID_INT); + P(dir_ignore, INVALID_INT); + +#ifdef INIT_PARAM + P(omega, 1.0); + P(alpha, 0.0); + P(reunit_interval, 10); + P(verbose_interval, 100); + P(compute_theta, true); + P(use_theta, false); +#else + P(omega, INVALID_DOUBLE); + P(alpha, INVALID_DOUBLE); + P(reunit_interval, INVALID_INT); + P(verbose_interval, INVALID_INT); +#endif + +#ifdef INIT_PARAM + return ret; +#endif +} + #if defined INIT_PARAM QudaBLASParam newQudaBLASParam(void) { diff --git a/lib/covariant_derivative.cu b/lib/covariant_derivative.cu index df7767cdd1..9d62a67b72 100644 --- a/lib/covariant_derivative.cu +++ b/lib/covariant_derivative.cu @@ -47,7 +47,7 @@ namespace quda long long flops() const override { - int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops + int mv_flops = Arg::shift ? 0 : (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops int num_mv_multiply = in.Nspin(); int ghost_flops = num_mv_multiply * mv_flops; int dim = arg.mu % 4; @@ -86,7 +86,7 @@ namespace quda long long bytes() const override { - int gauge_bytes = arg.reconstruct * in.Precision(); + int gauge_bytes = Arg::shift ? 0 : arg.reconstruct * in.Precision(); int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() + (isFixed::value ? 
sizeof(float) : 0); int ghost_bytes = gauge_bytes + 3 * spinor_bytes; // 3 since we have to load the partial @@ -128,6 +128,7 @@ namespace quda auto key = Dslash::tuneKey(); strcat(key.aux, ",mu="); u32toa(key.aux + strlen(key.aux), arg.mu); + if constexpr (Arg::shift) { strcat(key.aux, ",shift"); } return key; } }; @@ -135,20 +136,32 @@ namespace quda template struct CovDevApply { CovDevApply(cvector_ref &out, cvector_ref &in, - cvector_ref &, const GaugeField &U, int mu, int parity, bool dagger, + cvector_ref &, const GaugeField &U, int mu, int parity, bool dagger, bool shift, const int *comm_override, TimeProfile &profile) { constexpr int nDim = 4; auto halo = ColorSpinorField::create_comms_batch(in, 1, false); if (in.Nspin() == 4) { - CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); - CovDev covDev(arg, out, in, halo); - dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + if (shift) { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + } else { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + } } else if (in.Nspin() == 1) { - CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); - CovDev covDev(arg, out, in, halo); - dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + if (shift) { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + } else { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, out, in, halo, profile); + } } else { errorQuda("Spin not supported"); } @@ -156,13 +169,13 @@ namespace quda }; // Apply the covariant derivative 
operator - // out(x) = U_{\mu}(x)in(x+mu) for mu = 0...3 - // out(x) = U^\dagger_mu'(x-mu')in(x-mu') for mu = 4...7 and we set mu' = mu-4 + // out(x) = U_{\mu}(x)in(x+\mu) for mu = 0...3 + // out(x) = U^\dagger_{\mu'}(x-\mu')in(x-\mu') for mu = 4...7 and we set mu' = mu-4 void ApplyCovDev(cvector_ref &out, cvector_ref &in, const GaugeField &U, - int mu, int parity, bool dagger, const int *comm_override, TimeProfile &profile) + int mu, int parity, bool dagger, bool shift, const int *comm_override, TimeProfile &profile) { if constexpr (is_enabled()) { - instantiate(out, in, in, U, mu, parity, dagger, comm_override, profile); + instantiate(out, in, in, U, mu, parity, dagger, shift, comm_override, profile); } else { errorQuda("Covariant derivative kernels have not been built"); } diff --git a/lib/gauge_covdev.cpp b/lib/gauge_covdev.cpp index 611fcd569d..60137bfb43 100644 --- a/lib/gauge_covdev.cpp +++ b/lib/gauge_covdev.cpp @@ -6,9 +6,15 @@ namespace quda { - GaugeCovDev::GaugeCovDev(const DiracParam &param) : Dirac(param), covdev_mu(param.covdev_mu) { } + GaugeCovDev::GaugeCovDev(const DiracParam &param) : + Dirac(param), covdev_mu(param.covdev_mu), covdev_shift(param.covdev_shift) + { + } - GaugeCovDev::GaugeCovDev(const GaugeCovDev &covDev) : Dirac(covDev), covdev_mu(covDev.covdev_mu) { } + GaugeCovDev::GaugeCovDev(const GaugeCovDev &covDev) : + Dirac(covDev), covdev_mu(covDev.covdev_mu), covdev_shift(covDev.covdev_shift) + { + } GaugeCovDev::~GaugeCovDev() { } @@ -16,6 +22,7 @@ namespace quda { { if (&covDev != this) Dirac::operator=(covDev); covdev_mu = covDev.covdev_mu; + covdev_shift = covDev.covdev_shift; return *this; } @@ -27,7 +34,7 @@ namespace quda { int comm_dim[4] = {}; // only switch on comms needed for mu derivative (FIXME - only communicate in the given direction) comm_dim[mu % 4] = comm_dim_partitioned(mu % 4); - ApplyCovDev(out, in, *gauge, mu, parity, dagger, comm_dim, profile); + ApplyCovDev(out, in, *gauge, mu, parity, dagger, false, comm_dim, profile);
} void GaugeCovDev::MCD(cvector_ref &out, cvector_ref &in, const int mu) const @@ -43,10 +50,39 @@ namespace quda { MCD(out, tmp, (mu + 4) % 8); } + void GaugeCovDev::DslashS(cvector_ref &out, cvector_ref &in, + QudaParity parity, int mu) const + { + checkSpinorAlias(in, out); + + int comm_dim[4] = {}; + // only switch on comms needed for mu derivative (FIXME - only communicate in the given direction) + comm_dim[mu % 4] = comm_dim_partitioned(mu % 4); + ApplyCovDev(out, in, *gauge, mu, parity, dagger, true, comm_dim, profile); + } + + void GaugeCovDev::MS(cvector_ref &out, cvector_ref &in, const int mu) const + { + checkFullSpinor(out, in); + DslashS(out, in, QUDA_INVALID_PARITY, mu); + } + + void GaugeCovDev::MdagMS(cvector_ref &out, cvector_ref &in, const int mu) const + { + auto tmp = getFieldTmp(out); + + MS(tmp, in, mu); + MS(out, tmp, (mu + 4) % 8); + } + void GaugeCovDev::Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const { - DslashCD(out, in, parity, covdev_mu); + if (covdev_shift) { + DslashS(out, in, parity, covdev_mu); + } else { + DslashCD(out, in, parity, covdev_mu); + } } void GaugeCovDev::DslashXpay(cvector_ref &, cvector_ref &, QudaParity, @@ -57,12 +93,20 @@ namespace quda { void GaugeCovDev::M(cvector_ref &out, cvector_ref &in) const { - MCD(out, in, covdev_mu); + if (covdev_shift) { + MS(out, in, covdev_mu); + } else { + MCD(out, in, covdev_mu); + } } void GaugeCovDev::MdagM(cvector_ref &out, cvector_ref &in) const { - MdagMCD(out, in, covdev_mu); + if (covdev_shift) { + MdagMS(out, in, covdev_mu); + } else { + MdagMCD(out, in, covdev_mu); + } } void GaugeCovDev::prepare(cvector_ref &, cvector_ref &, diff --git a/lib/gauge_fix_ovr2.cu b/lib/gauge_fix_ovr2.cu new file mode 100644 index 0000000000..334d32bb01 --- /dev/null +++ b/lib/gauge_fix_ovr2.cu @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quda +{ + + template class GaugeFix : TunableKernel1D + { + GaugeField 
&rot; + const GaugeField &u; + const Float omega; + const int dir_ignore; + const int fixDim; + const int parity; + unsigned int minThreads() const { return u.LocalVolumeCB(); } + + public: + GaugeFix(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore, int parity) : + TunableKernel1D(u), + rot(rot), + u(u), + omega(static_cast(omega)), + dir_ignore(dir_ignore), + fixDim((dir_ignore == 4) ? 4 : 3), + parity(parity) + { + strcat(aux, ",dir_ignore="); + i32toa(aux + strlen(aux), dir_ignore); + strcat(aux, ",parity="); + i32toa(aux + strlen(aux), parity); + if (omega != 1.0) { strcat(aux, ",over_relaxation"); } + strcat(aux, comm_dim_partitioned_string()); + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + if (omega == 1.0) { + if (parity == 0) { + FixGaugeArg arg(rot, u, omega, dir_ignore); + launch(tp, stream, arg); + } else if (parity == 1) { + FixGaugeArg arg(rot, u, omega, dir_ignore); + launch(tp, stream, arg); + } + } else { + if (parity == 0) { + FixGaugeArg arg(rot, u, omega, dir_ignore); + launch(tp, stream, arg); + } else if (parity == 1) { + FixGaugeArg arg(rot, u, omega, dir_ignore); + launch(tp, stream, arg); + } + } + } + + void preTune() { rot.backup(); } // defensive measure in case they alias + void postTune() { rot.restore(); } + + long long flops() const + { + auto mat_flops = u.Ncolor() * u.Ncolor() * (8ll * u.Ncolor() - 2ll); + return (fixDim * 2 + 2 * 3) * mat_flops * u.LocalVolumeCB(); + } + + long long bytes() const // 2 links per dim, 2 rot in per dim, 1 rot in, 1 rot out. 
+ { + return ((fixDim * 2) * u.Reconstruct() * u.Precision() + (1 + fixDim * 2 + 1) * rot.Reconstruct() * rot.Precision()) + * u.LocalVolumeCB(); + } + + }; // GaugeFix + + void gaugeFixOVRStep(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore) + { + checkPrecision(rot, u); + checkReconstruct(rot, u); + checkNative(rot, u); + + if (dir_ignore < 0 || dir_ignore > 3) { dir_ignore = 4; } + + // loop over parity + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + instantiate(rot, u, omega, dir_ignore, 0); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + rot.exchangeExtendedGhost(rot.R(), getProfile(), false); + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + instantiate(rot, u, omega, dir_ignore, 1); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + rot.exchangeExtendedGhost(rot.R(), getProfile(), false); + } + +} // namespace quda diff --git a/lib/gauge_fix_quality.cu b/lib/gauge_fix_quality.cu new file mode 100644 index 0000000000..22f27f7f22 --- /dev/null +++ b/lib/gauge_fix_quality.cu @@ -0,0 +1,78 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quda +{ + + template class GaugeFixingQuality : TunableReduction2D + { + const GaugeField &u; + const GaugeField &rot; + double *quality; + const int dir_ignore; + const int fixDim; + const bool compute_theta; + unsigned int minThreads() const { return u.LocalVolumeCB(); } + + public: + GaugeFixingQuality(const GaugeField &u, const GaugeField &rot, double quality[2], int dir_ignore, bool compute_theta) : + TunableReduction2D(u, 2), + u(u), + rot(rot), + quality(quality), + dir_ignore(dir_ignore), + fixDim((dir_ignore == 4) ? 
4 : 3), + compute_theta(compute_theta) + { + strcat(aux, ",dir_ignore="); + i32toa(aux + strlen(aux), dir_ignore); + if (compute_theta) { strcat(aux, ",compute_theta"); } + strcat(aux, comm_dim_partitioned_string()); + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + array value {}; + if (compute_theta) { + GaugeFixQualityArg arg(u, rot, dir_ignore); + launch(value, tp, stream, arg); + } else { + GaugeFixQualityArg arg(u, rot, dir_ignore); + launch(value, tp, stream, arg); + } + quality[0] = value[0] / static_cast(fixDim * u.Ncolor() * u.LocalVolume() * comm_size()); + quality[1] = value[1] / static_cast(u.Ncolor() * u.LocalVolume() * comm_size()); + } + + long long flops() const { return 0; } + long long bytes() const + { + return ((compute_theta ? 2 : 1) * fixDim * u.Reconstruct() * u.Precision() + + (1 + (compute_theta ? 2 : 1) * fixDim) * rot.Reconstruct() * rot.Precision()) + * u.LocalVolume(); + } + + }; // GaugeFixingQuality + + void gaugeFixQuality(double quality[2], const GaugeField &rot, const GaugeField &u, int dir_ignore, bool compute_theta) + { + checkPrecision(rot, u); + checkReconstruct(rot, u); + checkNative(rot, u); + + if (dir_ignore < 0 || dir_ignore > 3) { dir_ignore = 4; } + + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + instantiate(u, rot, quality, dir_ignore, compute_theta); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + } + +} // namespace quda diff --git a/lib/gauge_rotate.cu b/lib/gauge_rotate.cu new file mode 100644 index 0000000000..5f6f8029e1 --- /dev/null +++ b/lib/gauge_rotate.cu @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quda +{ + + template class GaugeRotate : TunableKernel3D + { + GaugeField &out; + const GaugeField &in; + const GaugeField &rot; + + unsigned int minThreads() const { return in.LocalVolumeCB(); } + + public: + GaugeRotate(GaugeField &out, const
GaugeField &in, const GaugeField &rot) : + TunableKernel3D(in, 2, 4), out(out), in(in), rot(rot) + { + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + launch(tp, stream, RotateGaugeArg(out, in, rot)); + } + + void preTune() { out.backup(); } // defensive measure in case they alias + void postTune() { out.restore(); } + + long long flops() const + { + auto mat_flops = in.Ncolor() * in.Ncolor() * (8ll * in.Ncolor() - 2ll); + return 2 * mat_flops * 4 * in.LocalVolume(); + } + long long bytes() const // 2 rot, 1 in, 1 out, per dim. + { + return (2 * rot.Reconstruct() * rot.Precision() + in.Reconstruct() * in.Precision() + + out.Reconstruct() * out.Precision()) + * 4 * in.LocalVolume(); + } + + }; // RotateGauge + + void gaugeRotate(GaugeField &out, const GaugeField &in, const GaugeField &rot) + { + checkPrecision(out, in, rot); + checkReconstruct(out, in, rot); + checkNative(out, in, rot); + + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + instantiate(out, in, rot); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + } + +} // namespace quda diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 67ec31436c..2131bbfb0d 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -167,7 +167,13 @@ static TimeProfile profileInvertMultiSrc("invertMultiSrcQuda"); static TimeProfile profileUpdateSplitGauge("UpdateSplitGauge"); //!< Profiler for invertMultiShiftQuda -static TimeProfile profileMulti("invertMultiShiftQuda"); +static TimeProfile profileInvertMultiShift("invertMultiShiftQuda"); + +//!< Profiler for MatQuda +static TimeProfile profileMat("MatQuda"); + +//!< Profiler for MatDagMatQuda +static TimeProfile profileMatDagMat("MatDagMatQuda"); //!< Profiler for eigensolveQuda static TimeProfile profileEigensolve("eigensolveQuda"); @@ -253,6 +259,12 @@ static TimeProfile profileMomAction("momActionQuda"); //!< Profiler for sink projection static 
TimeProfile profileSinkProject("sinkProjectQuda"); +//!< Profiler for performGaugeRotateQuda +static TimeProfile profileGaugeRotate("performGaugeRotateQuda"); + +//!< Profiler for performGaugeFixQuda +static TimeProfile profileGaugeFix("performGaugeFixQuda"); + //!< Profiler for endQuda static TimeProfile profileEnd("endQuda"); @@ -467,7 +479,6 @@ static void init_default_comms() #endif } - extern char* gitversion; /* @@ -1511,12 +1522,18 @@ void endQuda(void) profileDslash.Print(); profileInvert.Print(); profileInvertMultiSrc.Print(); - profileMulti.Print(); + profileInvertMultiShift.Print(); + profileMat.Print(); + profileMatDagMat.Print(); profileEigensolve.Print(); profileFatLink.Print(); profileGaugeForce.Print(); profileGaugeUpdate.Print(); profileExtendedGauge.Print(); + profileGaugeRotate.Print(); + profileGaugeFix.Print(); + GaugeFixFFTQuda.Print(); + GaugeFixOVRQuda.Print(); profileCloverForce.Print(); profileTMCloverForce.Print(); profileStaggeredForce.Print(); @@ -1639,6 +1656,7 @@ namespace quda { case QUDA_COVDEV_DSLASH: diracParam.type = QUDA_GAUGE_COVDEV_DIRAC; diracParam.covdev_mu = inv_param->covdev_mu; + diracParam.covdev_shift = inv_param->covdev_shift; break; default: errorQuda("Unsupported dslash_type %d", inv_param->dslash_type); @@ -2333,6 +2351,7 @@ void covDevQuda(void *h_out, void *h_in, int dir, QudaInvertParam *param) void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) { + auto profile = pushProfile(profileMat, inv_param); pushVerbosity(inv_param->verbosity); const auto &gauge = (inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? 
*gaugePrecise : *gaugeFatPrecise; @@ -2360,6 +2379,8 @@ void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) cudaParam.location = QUDA_CUDA_FIELD_LOCATION; ColorSpinorField out(cudaParam); + profileMat.TPSTART(QUDA_PROFILE_COMPUTE); + DiracParam diracParam; setDiracParam(diracParam, inv_param, pc); @@ -2385,6 +2406,8 @@ void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) } } + profileMat.TPSTOP(QUDA_PROFILE_COMPUTE); + cpuParam.v = h_out; cpuParam.location = inv_param->output_location; ColorSpinorField out_h(cpuParam); @@ -2397,6 +2420,7 @@ void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) void MatDagMatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) { + auto profile = pushProfile(profileMatDagMat, inv_param); pushVerbosity(inv_param->verbosity); const auto &gauge = (inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; @@ -2426,6 +2450,8 @@ void MatDagMatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) // double kappa = inv_param->kappa; // if (inv_param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) kappa *= gaugePrecise->anisotropy; + profileMatDagMat.TPSTART(QUDA_PROFILE_COMPUTE); + DiracParam diracParam; setDiracParam(diracParam, inv_param, pc); @@ -2451,6 +2477,8 @@ void MatDagMatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) } } + profileMatDagMat.TPSTOP(QUDA_PROFILE_COMPUTE); + cpuParam.v = h_out; cpuParam.location = inv_param->output_location; ColorSpinorField out_h(cpuParam); @@ -3591,7 +3619,7 @@ void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, Quda */ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) { - auto profile = pushProfile(profileMulti, param); + auto profile = pushProfile(profileInvertMultiShift, param); profilerStart(__func__); if (!initialized) errorQuda("QUDA not initialized"); @@ -3724,7 +3752,7 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) std::vector &x = 
solutionResident; std::vector p; - profileMulti.TPSTART(QUDA_PROFILE_PREAMBLE); + profileInvertMultiShift.TPSTART(QUDA_PROFILE_PREAMBLE); // Check source norms double nb = blas::norm2(b); @@ -3740,7 +3768,7 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) // rescale massRescale(b, *param, true); - profileMulti.TPSTOP(QUDA_PROFILE_PREAMBLE); + profileInvertMultiShift.TPSTOP(QUDA_PROFILE_PREAMBLE); DiracMatrix *m, *mSloppy; @@ -3899,11 +3927,11 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) if (!param->make_resident_solution) h_x[i] = x[i]; } - profileMulti.TPSTART(QUDA_PROFILE_EPILOGUE); + profileInvertMultiShift.TPSTART(QUDA_PROFILE_EPILOGUE); if (!param->make_resident_solution) solutionResident.clear(); - profileMulti.TPSTOP(QUDA_PROFILE_EPILOGUE); + profileInvertMultiShift.TPSTOP(QUDA_PROFILE_EPILOGUE); delete d; delete dSloppy; @@ -6062,7 +6090,148 @@ void performAdjGFlowHier(void **h_out, void **h_in, QudaInvertParam *inv_param, popVerbosity(); } -/* save list of gauge vectors */ +void performGaugeRotateQuda(void *rotation, void *gauge, QudaGaugeParam *param) +{ + auto profile = pushProfile(profileGaugeRotate); + checkGaugeParam(param); + lat_dim_t R1; + for (int d = 0; d < 4; d++) { R1[d] = (redundant_comms || commDimPartitioned(d)); } + + GaugeFieldParam gParam(*param); + gParam.location = QUDA_CPU_FIELD_LOCATION; + gParam.gauge = rotation; + gParam.geometry = QUDA_SCALAR_GEOMETRY; + GaugeField cpuRotation(gParam); + gParam.gauge = gauge; + gParam.geometry = QUDA_VECTOR_GEOMETRY; + GaugeField cpuGauge = (!param->use_resident_gauge || param->return_result_gauge) ? 
GaugeField(gParam) : GaugeField(); + + gParam.create = QUDA_NULL_FIELD_CREATE; + gParam.location = QUDA_CUDA_FIELD_LOCATION; + gParam.link_type = param->type; + gParam.reconstruct = param->reconstruct; + gParam.setPrecision(param->cuda_prec, true); + gParam.geometry = QUDA_SCALAR_GEOMETRY; + GaugeField cudaRotation(gParam); + gParam.geometry = QUDA_VECTOR_GEOMETRY; + GaugeField cudaInGauge = param->use_resident_gauge ? gaugePrecise->create_alias() : GaugeField(gParam); + GaugeField cudaOutGauge(gParam); + + cudaRotation.copy(cpuRotation); + if (!param->use_resident_gauge) { cudaInGauge.copy(cpuGauge); } + + GaugeField *cudaRotationEx = createExtendedGauge(cudaRotation, R1, profileGaugeRotate); + + gaugeRotate(cudaOutGauge, cudaInGauge, *cudaRotationEx); + + delete cudaRotationEx; + if (param->return_result_gauge) { cpuGauge.copy(cudaOutGauge); } + if (param->make_resident_gauge) { + freeUniqueGaugeQuda(QUDA_WILSON_LINKS); + gaugePrecise = new GaugeField(); + std::exchange(*gaugePrecise, cudaOutGauge); + updateExtendedGaugeResident(true, R, profileGaugeRotate); + } +} + +void performGaugeFixQuda(void *rotation, void *gauge, QudaGaugeParam *param, QudaGaugeFixParam *fix_param) +{ + auto profile = pushProfile(profileGaugeFix); + checkGaugeParam(param); + checkGaugeFixParam(fix_param); + int *reunit_fails_h = static_cast(mapped_malloc(sizeof(int))); + int *reunit_fails_d = static_cast(get_mapped_device_pointer(reunit_fails_h)); + lat_dim_t R1; + for (int d = 0; d < 4; d++) { R1[d] = (redundant_comms || commDimPartitioned(d)); } + + GaugeFieldParam gParam(*param); + gParam.location = QUDA_CPU_FIELD_LOCATION; + gParam.gauge = rotation; + gParam.geometry = QUDA_SCALAR_GEOMETRY; + GaugeField cpuRotation(gParam); + gParam.gauge = gauge; + gParam.geometry = QUDA_VECTOR_GEOMETRY; + GaugeField cpuGauge = (!param->use_resident_gauge || param->return_result_gauge) ? 
GaugeField(gParam) : GaugeField(); + + if (param->use_resident_gauge && !gaugePrecise) errorQuda("No resident gauge field to use"); + gParam.create = QUDA_NULL_FIELD_CREATE; + gParam.location = QUDA_CUDA_FIELD_LOCATION; + gParam.link_type = param->type; + gParam.reconstruct = param->reconstruct; + gParam.setPrecision(param->cuda_prec, true); + gParam.geometry = QUDA_SCALAR_GEOMETRY; + GaugeField cudaRotation(gParam); + gParam.geometry = QUDA_VECTOR_GEOMETRY; + GaugeField cudaInGauge = param->use_resident_gauge ? gaugePrecise->create_alias() : GaugeField(gParam); + GaugeField cudaOutGauge + = (param->make_resident_gauge || param->return_result_gauge) ? GaugeField(gParam) : GaugeField(); + + cudaRotation.copy(cpuRotation); + if (!param->use_resident_gauge) { cudaInGauge.copy(cpuGauge); } + + GaugeField *cudaRotationEx = createExtendedGauge(cudaRotation, R1, profileGaugeFix); + GaugeField *cudaInGaugeEx = createExtendedGauge(cudaInGauge, R1, profileGaugeFix); + + double functional_old, functional, theta, diff, criterion, quality[2]; + bool compute_theta = fix_param->compute_theta; + bool use_theta = fix_param->use_theta; + if (use_theta && !compute_theta) { errorQuda("compute_theta must be true if use_theta is true"); } + gaugeFixQuality(quality, *cudaRotationEx, *cudaInGaugeEx, fix_param->dir_ignore, compute_theta); + functional = quality[0]; + theta = quality[1]; + diff = 1.0; + criterion = use_theta ? 
theta : diff; + int iter = 0; + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, theta); + while (iter < fix_param->maxiter && criterion > fix_param->tol) { + gaugeFixOVRStep(*cudaRotationEx, *cudaInGaugeEx, fix_param->omega, fix_param->dir_ignore); + gaugeFixQuality(quality, *cudaRotationEx, *cudaInGaugeEx, fix_param->dir_ignore, compute_theta); + functional_old = functional; + functional = quality[0]; + theta = quality[1]; + diff = fabs((functional - functional_old) / functional_old); + criterion = use_theta ? theta : diff; + iter++; + if (iter % fix_param->reunit_interval == 0) { + *reunit_fails_h = 0; + unitarizeLinks(*cudaRotationEx, *cudaRotationEx, reunit_fails_d); + if (*reunit_fails_h > 0) errorQuda("Error in the unitarization (%d errors)\n", *reunit_fails_h); + } + if (iter % fix_param->verbose_interval == 0) { + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, + theta); + } + } + if (iter < fix_param->maxiter) { + if (iter % fix_param->reunit_interval != 0) { + *reunit_fails_h = 0; + unitarizeLinks(*cudaRotationEx, *cudaRotationEx, reunit_fails_d); + if (*reunit_fails_h > 0) errorQuda("Error in the unitarization (%d errors)\n", *reunit_fails_h); + } + if (iter % fix_param->verbose_interval != 0) { + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, + theta); + } + } + + // copy the field back to the host + copyExtendedGauge(cudaRotation, *cudaRotationEx, QUDA_CUDA_FIELD_LOCATION); + cpuRotation.copy(cudaRotation); + if (param->make_resident_gauge || param->return_result_gauge) { + gaugeRotate(cudaOutGauge, cudaInGauge, *cudaRotationEx); + } + if (param->return_result_gauge) { cpuGauge.copy(cudaOutGauge); } + + host_free(reunit_fails_h); + delete cudaRotationEx; + delete cudaInGaugeEx; + if (param->make_resident_gauge) { + freeUniqueGaugeQuda(QUDA_WILSON_LINKS); + 
gaugePrecise = new GaugeField(); + std::exchange(*gaugePrecise, cudaOutGauge); + updateExtendedGaugeResident(true, R, profileGaugeFix); + } +} int computeGaugeFixingOVRQuda(void *gauge, const unsigned int gauge_dir, const unsigned int Nsteps, const unsigned int verbose_interval, const double relax_boost, const double tolerance, diff --git a/lib/pgauge_init.cu b/lib/pgauge_init.cu index 892201a672..a2293264a9 100644 --- a/lib/pgauge_init.cu +++ b/lib/pgauge_init.cu @@ -25,15 +25,23 @@ namespace quda { void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - launch(tp, stream, InitGaugeColdArg(U)); + if (U.Geometry() == QUDA_SCALAR_GEOMETRY) { + launch(tp, stream, InitGaugeColdArg(U)); + } else if (U.Geometry() == QUDA_VECTOR_GEOMETRY) { + launch(tp, stream, InitGaugeColdArg(U)); + } else if (U.Geometry() == QUDA_TENSOR_GEOMETRY) { + launch(tp, stream, InitGaugeColdArg(U)); + } else { + errorQuda("Unsupported geometry %d\n", U.Geometry()); + } } long long flops() const { return 0; } long long bytes() const { return U.Bytes(); } }; - template - class InitGaugeHot : TunableKernel1D { + template class InitGaugeHot : TunableKernel1D + { const GaugeField &U; RNG &rng; unsigned int minThreads() const { return U.LocalVolumeCB(); } @@ -52,7 +60,15 @@ namespace quda { void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - launch(tp, stream, InitGaugeHotArg(U, rng.State())); + if (U.Geometry() == QUDA_SCALAR_GEOMETRY) { + launch(tp, stream, InitGaugeHotArg(U, rng.State())); + } else if (U.Geometry() == QUDA_VECTOR_GEOMETRY) { + launch(tp, stream, InitGaugeHotArg(U, rng.State())); + } else if (U.Geometry() == QUDA_TENSOR_GEOMETRY) { + launch(tp, stream, InitGaugeHotArg(U, rng.State())); + } else { + errorQuda("Unsupported geometry %d\n", U.Geometry()); + } } void preTune() { rng.backup(); } diff --git a/lib/solve.cpp b/lib/solve.cpp index 79fa868633..2cb2a520af 100644 --- 
a/lib/solve.cpp +++ b/lib/solve.cpp @@ -113,7 +113,7 @@ namespace quda } if (param.inv_type == QUDA_MG_INVERTER) errorQuda("Multigrid solver doesn't support distance preconditioning"); - if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { + if (param.cuda_prec != QUDA_DOUBLE_PRECISION) { warningQuda( "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); } @@ -134,6 +134,9 @@ namespace quda bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || (param.solve_type == QUDA_DIRECT_PC_SOLVE); bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || (param.solve_type == QUDA_NORMERR_PC_SOLVE); + bool distance_pc = (param.distance_pc_alpha0 != 0.0 && param.distance_pc_t0 >= 0); + distanceReweight(b, param, true); + auto nb = blas::norm2(b); for (auto &bi : nb) { if (bi == 0.0) errorQuda("Source has zero norm"); @@ -144,7 +147,8 @@ namespace quda for (auto &xi : x_norm) logQuda(QUDA_VERBOSE, "Initial guess: %g\n", xi); } // rescale the source and solution vectors to help prevent the onset of underflow - if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + // and force to normalize the source if distance preconditioning is used + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION || distance_pc) { auto nb_inv(nb); for (auto &bi : nb_inv) bi = 1 / sqrt(bi); blas::ax(nb_inv, b); @@ -152,7 +156,6 @@ namespace quda } massRescale(b, param, false); - distanceReweight(b, param, true); std::vector in(b.size()); std::vector out(b.size()); @@ -295,9 +298,7 @@ namespace quda dirac.reconstruct(x, b, param.solution_type); - distanceReweight(x, param, false); - - if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION || distance_pc) { // rescale the solution for (auto &bi : nb) bi = sqrt(bi); blas::ax(nb, x); @@ -314,6 +315,8 @@ namespace quda param.action[1] = action[0].imag(); } + 
distanceReweight(x, param, false); + getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } diff --git a/lib/spinor_reweight.cu b/lib/spinor_reweight.cu index 13c45dc83c..20ff9331b1 100644 --- a/lib/spinor_reweight.cu +++ b/lib/spinor_reweight.cu @@ -9,7 +9,7 @@ namespace quda template class SpinorDistanceReweight : TunableKernel2D { ColorSpinorField &v; - Float alpha0; + double alpha0; int t0; unsigned int minThreads() const { return v.VolumeCB(); } diff --git a/lib/unitarize_links_quda.cu b/lib/unitarize_links_quda.cu index 8bddafea55..69e807a798 100644 --- a/lib/unitarize_links_quda.cu +++ b/lib/unitarize_links_quda.cu @@ -108,11 +108,8 @@ namespace quda { unsigned int minThreads() const { return in.VolumeCB(); } public: - UnitarizeLinks(GaugeField &out, const GaugeField &in, int* fails) : - TunableKernel3D(in, 2, 4), - out(out), - in(in), - fails(fails) + UnitarizeLinks(GaugeField &out, const GaugeField &in, int *fails) : + TunableKernel3D(in, 2, in.Geometry()), out(out), in(in), fails(fails) { apply(device::get_default_stream()); qudaDeviceSynchronize(); // need to synchronize to ensure failure write has completed diff --git a/tests/covdev_test.cpp b/tests/covdev_test.cpp index 2f9e804266..06981e311f 100644 --- a/tests/covdev_test.cpp +++ b/tests/covdev_test.cpp @@ -47,7 +47,11 @@ void init(int argc, char **argv) // Allocate host side memory for the gauge field. 
for (int dir = 0; dir < 4; dir++) { links[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); } - constructHostGaugeField(links, gauge_param, argc, argv); + if (covdev_shift) { + constructIdentityGaugeField(links, gauge_param.cpu_prec); + } else { + constructHostGaugeField(links, gauge_param, argc, argv); + } // cpuLink is only used for ghost allocation GaugeFieldParam cpuParam(gauge_param, links); @@ -61,12 +65,12 @@ void end(void) cpuLink = {}; } -double dslashCUDA(GaugeCovDev &dirac, ColorSpinorField &out, const ColorSpinorField &in, int niter, int mu) +double dslashCUDA(GaugeCovDev &dirac, ColorSpinorField &out, const ColorSpinorField &in, int niter) { device_timer_t timer; timer.start(); - for (int i = 0; i < niter; i++) dirac.MCD(out, in, mu); + for (int i = 0; i < niter; i++) dirac.M(out, in); timer.stop(); return timer.last(); @@ -85,6 +89,7 @@ std::array covdev_test(test_t param) QudaPrecision test_prec = ::testing::get<0>(param); QudaDagType test_dagger = ::testing::get<1>(param); int mu = ::testing::get<2>(param); + bool shift = ::testing::get<3>(param); printfQuda("Links sending..."); gauge_param.cuda_prec = test_prec; @@ -101,6 +106,10 @@ std::array covdev_test(test_t param) inv_param.dslash_type = QUDA_COVDEV_DSLASH; // ensure we use the correct dslash inv_param.solution_type = QUDA_MAT_SOLUTION; + inv_param.covdev_nspin = test_type == 0 ? 4 : 1; + inv_param.covdev_mu = mu + (test_dagger ? 4 : 0); + inv_param.covdev_shift = shift; + ColorSpinorParam csParam; csParam.nColor = nColor; csParam.nSpin = test_type == 0 ? 
4 : 1; // use --test 1 for staggered case @@ -153,12 +162,12 @@ std::array covdev_test(test_t param) { // warm-up run printfQuda("Tuning...\n"); - dslashCUDA(dirac, cudaSpinorOut, cudaSpinor, 1, muQuda); + dslashCUDA(dirac, cudaSpinorOut, cudaSpinor, 1); } printfQuda("Executing %d kernel loop(s)...", niter); auto flops0 = quda::Tunable::flops_global(); - double secs = dslashCUDA(dirac, cudaSpinorOut, cudaSpinor, niter, muQuda); + double secs = dslashCUDA(dirac, cudaSpinorOut, cudaSpinor, niter); auto flops = (quda::Tunable::flops_global() - flops0); spinorOut = cudaSpinorOut; @@ -206,7 +215,7 @@ int main(int argc, char **argv) if (enable_testing) { // tests are defined in invert_test_gtest.hpp result = test.execute(); } else { // - covdev_test(test_t {prec, dagger ? QUDA_DAG_YES : QUDA_DAG_NO, covdev_mu}); + covdev_test(test_t {prec, dagger ? QUDA_DAG_YES : QUDA_DAG_NO, covdev_mu, covdev_shift}); } end(); diff --git a/tests/covdev_test_gtest.hpp b/tests/covdev_test_gtest.hpp index 4e0225bdcd..57c16a6d4b 100644 --- a/tests/covdev_test_gtest.hpp +++ b/tests/covdev_test_gtest.hpp @@ -2,7 +2,7 @@ #include #include -using test_t = ::testing::tuple; +using test_t = ::testing::tuple; bool skip_test(test_t param) { @@ -41,6 +41,7 @@ std::string gettestname(::testing::TestParamInfo param) str += get_prec_str(::testing::get<0>(param.param)); str += std::string("_") + get_dag_str(::testing::get<1>(param.param)); str += std::string("_mu") + std::to_string(::testing::get<2>(param.param)); + if (::testing::get<3>(param.param)) str += std::string("_shift"); return str; } @@ -51,5 +52,6 @@ using ::testing::Values; auto precisions = Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION, QUDA_HALF_PRECISION); auto dagger_opt = Values(QUDA_DAG_YES, QUDA_DAG_NO); auto mu_values = Values(0, 1, 2, 3); +auto shift_values = Values(false, true); -INSTANTIATE_TEST_SUITE_P(covdevtst, CovDevTest, Combine(precisions, dagger_opt, mu_values), gettestname); +INSTANTIATE_TEST_SUITE_P(covdevtst, 
CovDevTest, Combine(precisions, dagger_opt, mu_values, shift_values), gettestname); diff --git a/tests/gauge_alg_test.cpp b/tests/gauge_alg_test.cpp index b7cc48fe7b..f586c7ed73 100644 --- a/tests/gauge_alg_test.cpp +++ b/tests/gauge_alg_test.cpp @@ -76,6 +76,81 @@ struct GaugeAlgTest : public ::testing::TestWithParam { return (std::abs(1.0 - detu.x) < prec_val && std::abs(detu.y) < prec_val); } + void gaugeFixOVR_v2(GaugeField &gauge, int dir_ignore, double tol, int maxiter, double omega, bool use_theta, + int reunit_interval, int verbose_interval) + { + lat_dim_t R = {0, 0, 0, 0}; + for (int d = 0; d < 4; d++) { + if (comm_dim_partitioned(d)) R[d] = 2; + } + static TimeProfile GaugeFix("GaugeFix"); + int *reunit_fails_h = static_cast(mapped_malloc(sizeof(int))); + int *reunit_fails_d = static_cast(get_mapped_device_pointer(reunit_fails_h)); + + GaugeFieldParam gauge_field_param(param, nullptr); + gauge_field_param.ghostExchange = QUDA_GHOST_EXCHANGE_NO; + gauge_field_param.location = QUDA_CUDA_FIELD_LOCATION; + gauge_field_param.create = QUDA_NULL_FIELD_CREATE; + gauge_field_param.reconstruct = param.reconstruct; + gauge_field_param.setPrecision(precision, true); + gauge_field_param.geometry = QUDA_SCALAR_GEOMETRY; + GaugeField *tmp = new GaugeField(gauge_field_param); + InitGaugeField(*tmp); + GaugeField *rot = createExtendedGauge(*tmp, R, GaugeFix); + delete tmp; + gauge_field_param.geometry = QUDA_VECTOR_GEOMETRY; + tmp = new GaugeField(gauge_field_param); + copyExtendedGauge(*tmp, gauge, QUDA_CUDA_FIELD_LOCATION); + + double functional_old, functional, theta, diff, criterion, quality[2]; + bool compute_theta = true; + if (use_theta && !compute_theta) { errorQuda("compute_theta must be true if use_theta is true"); } + gaugeFixQuality(quality, *rot, gauge, dir_ignore, compute_theta); + functional = quality[0]; + theta = quality[1]; + diff = 1.0; + criterion = use_theta ? 
theta : diff; + int iter = 0; + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, theta); + while (iter < maxiter && criterion > tol) { + gaugeFixOVRStep(*rot, gauge, omega, dir_ignore); + gaugeFixQuality(quality, *rot, gauge, dir_ignore, compute_theta); + functional_old = functional; + functional = quality[0]; + theta = quality[1]; + diff = fabs((functional - functional_old) / functional_old); + criterion = use_theta ? theta : diff; + iter++; + if (iter % reunit_interval == 0) { + *reunit_fails_h = 0; + unitarizeLinks(*rot, *rot, reunit_fails_d); + if (*reunit_fails_h > 0) errorQuda("Error in the unitarization (%d errors)\n", *reunit_fails_h); + } + if (iter % verbose_interval == 0) { + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, + theta); + } + } + if (iter < maxiter) { + if (iter % reunit_interval != 0) { + *reunit_fails_h = 0; + unitarizeLinks(*rot, *rot, reunit_fails_d); + if (*reunit_fails_h > 0) errorQuda("Error in the unitarization (%d errors)\n", *reunit_fails_h); + } + if (iter % verbose_interval != 0) { + logQuda(QUDA_SUMMARIZE, "%d iter: functional=%.15f, functional diff=%le, theta=%le\n", iter, functional, diff, + theta); + } + } + + host_free(reunit_fails_h); + gaugeRotate(*tmp, *tmp, *rot); + delete rot; + copyExtendedGauge(gauge, *tmp, QUDA_CUDA_FIELD_LOCATION); + delete tmp; + gauge.exchangeExtendedGhost(gauge.R(), false); + } + virtual void SetUp() { #ifndef QUDA_BUILD_NATIVE_FFT // skip FFT tests if FFT not available @@ -201,6 +276,7 @@ struct GaugeAlgTest : public ::testing::TestWithParam { break; case 1: run_ovr(); break; case 2: run_fft(); break; + case 3: run_ovr2(); break; default: errorQuda("Invalid test type %d", test_type); } @@ -259,6 +335,19 @@ struct GaugeAlgTest : public ::testing::TestWithParam { } } } + virtual void run_ovr2() + { + if (execute) { + gaugeFixOVR_v2(*U, gf_gauge_dir, gf_tolerance, gf_maxiter, 
gf_ovr_relaxation_boost, gf_theta_condition, + gf_reunit_interval, gf_verbosity_interval); + auto plaq_gf = plaquette(*U); + printfQuda("Plaq: %.16e, %.16e, %.16e\n", plaq.x, plaq.y, plaq.z); + printfQuda("Plaq GF: %.16e, %.16e, %.16e\n", plaq_gf.x, plaq_gf.y, plaq_gf.z); + ASSERT_TRUE(comparePlaquette(plaq, plaq_gf)); + // Save if output string is specified + if (gauge_store) save_gauge(); + } + } virtual void save_gauge() { @@ -356,6 +445,32 @@ TEST_P(GaugeAlgTest, Coulomb_FFT) } } +TEST_P(GaugeAlgTest, Landau_Overrelaxation_v2) +{ + if (execute) { + printfQuda("Landau gauge fixing with overrelaxation v2\n"); + gaugeFixOVR_v2(*U, 4, gf_tolerance, gf_maxiter, gf_ovr_relaxation_boost, gf_theta_condition, gf_reunit_interval, + gf_verbosity_interval); + auto plaq_gf = plaquette(*U); + printfQuda("Plaq: %.16e, %.16e, %.16e\n", plaq.x, plaq.y, plaq.z); + printfQuda("Plaq GF: %.16e, %.16e, %.16e\n", plaq_gf.x, plaq_gf.y, plaq_gf.z); + ASSERT_TRUE(comparePlaquette(plaq, plaq_gf)); + } +} + +TEST_P(GaugeAlgTest, Coulomb_Overrelaxation_v2) +{ + if (execute) { + printfQuda("Coulomb gauge fixing with overrelaxation v2\n"); + gaugeFixOVR_v2(*U, 3, gf_tolerance, gf_maxiter, gf_ovr_relaxation_boost, gf_theta_condition, gf_reunit_interval, + gf_verbosity_interval); + auto plaq_gf = plaquette(*U); + printfQuda("Plaq: %.16e, %.16e, %.16e\n", plaq.x, plaq.y, plaq.z); + printfQuda("Plaq GF: %.16e, %.16e, %.16e\n", plaq_gf.x, plaq_gf.y, plaq_gf.z); + ASSERT_TRUE(comparePlaquette(plaq, plaq_gf)); + } +} + struct gauge_alg_test : quda_test { void display_info() const override diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index 8eaddb4d80..48e7d2998c 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -343,6 +343,7 @@ int prop_n_sources = 1; QudaPrecision prop_save_prec = QUDA_SINGLE_PRECISION; int covdev_mu = 3; +bool covdev_shift = false; // Parameters for the (gaussian) quark smearing operator int 
smear_n_steps = 50; @@ -1353,4 +1354,5 @@ void add_covdev_option_group(std::shared_ptr quda_app) { auto opgroup = quda_app->add_option_group("Covdev", "Options controlling cov derivative parameteres"); opgroup->add_option("--covdev-mu", covdev_mu, "Set the direction for the covariant derivative"); + opgroup->add_option("--covdev-shift", covdev_shift, "Apply simple shift instead of the covariant derivative"); } diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index 80d9c85080..5be47b2ddb 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -607,3 +607,4 @@ extern bool enable_testing; extern bool detratio; extern int covdev_mu; +extern bool covdev_shift;