Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f7dcc56
Add new gauge fixing algorithm to return the rotation field.
SaltyChiang Jul 23, 2024
2db02fb
Add over relaxation.
SaltyChiang Jul 23, 2024
39844da
Make compiler happy.
SaltyChiang Jul 23, 2024
07ec4f7
Add docstring.
SaltyChiang Jul 23, 2024
3f98f5d
Don't return `double2` struct.
SaltyChiang Jul 23, 2024
6420ec3
Add new members to `QudaGaugeParam` to control the gauge fixing.
SaltyChiang Jul 24, 2024
7e168a2
Merge remote-tracking branch 'upstream/develop' into feature/new-gaug…
SaltyChiang Apr 30, 2025
5f6ff83
Add GaugeFixParam for gauge fixing algorithms.
SaltyChiang May 1, 2025
a4cb908
Add `gaugeRotateQuda` and `spinorRotateQuda` interface.
SaltyChiang May 3, 2025
ec06e81
Enable shift only mode for the covariant derivative kernel.
SaltyChiang May 7, 2025
e8e69a8
Use `gaugePrecise` for `performGaugeFixQuda` and `performGaugeRotateQ…
SaltyChiang May 8, 2025
91f6b4c
Merge remote-tracking branch 'upstream/develop' into feature/new-gaug…
SaltyChiang May 8, 2025
4603047
Update aux string.
SaltyChiang May 8, 2025
377dfdc
Fusing `gaugeRotate` and `gaugeFixQuality` into a single kernel.
SaltyChiang May 9, 2025
8a996fe
Fix possible divergence if distance preconditioning is used.
SaltyChiang May 9, 2025
5c6b0fa
Merge remote-tracking branch 'upstream/develop' into feature/new-gaug…
SaltyChiang May 9, 2025
24a11d1
Fix aux strings.
SaltyChiang May 10, 2025
598eb4d
Add gauge fixing test v2 to gauge_alg_test.cpp.
SaltyChiang May 10, 2025
30547b7
Fix potential nan for gauge fixing with fp32.
SaltyChiang May 11, 2025
4edc1ba
Add the test for shift-only mode in QUDA_COVDEV_DSLASH.
SaltyChiang May 11, 2025
5ab6ad6
Apply clang-format.
SaltyChiang May 11, 2025
5771a34
Use `laplace_nspin` and `covdev_nspin` instead of an ambiguous `stagg…
SaltyChiang May 16, 2025
6ceef5b
Fix a bug in covdev_test.
SaltyChiang May 23, 2025
1d0acea
Merge remote-tracking branch 'upstream/develop' into feature/new-gaug…
SaltyChiang May 23, 2025
9f6249c
Merge branch 'develop' into feature/new-gauge-fixing-ovr
SaltyChiang Jun 5, 2025
3f8e2f6
Fix type of `alpha0`.
SaltyChiang Jun 5, 2025
c6e9ad4
Merge branch 'develop' into feature/new-gauge-fixing-ovr
SaltyChiang Nov 14, 2025
6119bfd
Add covdev test with shift.
SaltyChiang Nov 14, 2025
6aa7bb1
Fix bug in covdev_test.
SaltyChiang Nov 14, 2025
dcfe930
Merge remote-tracking branch 'upstream/develop' into feature/new-gaug…
SaltyChiang Dec 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ namespace quda
QudaFieldLocation location = QUDA_CPU_FIELD_LOCATION) :
LatticeFieldParam(4, X, 0, location, inv_param.cpu_prec),
nColor(3),
nSpin((inv_param.dslash_type == QUDA_ASQTAD_DSLASH || inv_param.dslash_type == QUDA_STAGGERED_DSLASH
|| inv_param.dslash_type == QUDA_LAPLACE_DSLASH) ?
1 :
4),
nSpin((inv_param.dslash_type == QUDA_LAPLACE_DSLASH) ? inv_param.laplace_nspin :
(inv_param.dslash_type == QUDA_COVDEV_DSLASH) ? inv_param.covdev_nspin :
(inv_param.dslash_type == QUDA_ASQTAD_DSLASH || inv_param.dslash_type == QUDA_STAGGERED_DSLASH) ? 1 :
4),
twistFlavor(inv_param.twist_flavor),
gammaBasis(nSpin == 4 ? inv_param.gamma_basis : QUDA_DEGRAND_ROSSI_GAMMA_BASIS),
create(QUDA_REFERENCE_FIELD_CREATE),
Expand Down
8 changes: 8 additions & 0 deletions include/dirac_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace quda {
GaugeField *longGauge; // used by staggered only
int laplace3D;
int covdev_mu;
bool covdev_shift;
CloverField *clover;
GaugeField *xInvKD; // used for the Kahler-Dirac operator only

Expand Down Expand Up @@ -111,6 +112,7 @@ namespace quda {
printfQuda("mass = %g\n", mass);
printfQuda("laplace3D = %d\n", laplace3D);
printfQuda("covdev_mu = %d\n", covdev_mu);
printfQuda("covdev_shift = %d\n", covdev_shift);
printfQuda("m5 = %g\n", m5);
printfQuda("Ls = %d\n", Ls);
printfQuda("matpcType = %d\n", matpcType);
Expand Down Expand Up @@ -2249,6 +2251,7 @@ namespace quda {

protected:
int covdev_mu;
int covdev_shift;

public:
GaugeCovDev(const DiracParam &param);
Expand All @@ -2262,6 +2265,11 @@ namespace quda {
virtual void MCD(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const int mu) const;
virtual void MdagMCD(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const int mu) const;

virtual void DslashS(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, QudaParity parity,
int mu) const;
virtual void MS(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const int mu) const;
virtual void MdagMS(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const int mu) const;

virtual void Dslash(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
QudaParity parity) const override;
virtual void DslashXpay(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
Expand Down
3 changes: 2 additions & 1 deletion include/dslash_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,11 +793,12 @@ namespace quda
@param[in] mu Direction of the derivative. For mu > 3 it goes backwards
@param[in] parity Destination parity
@param[in] dagger Whether this is for the dagger operator
@param[in] shift Whether to apply the shift instead of the covariant derivative
@param[in] comm_override Override for which dimensions are partitioned
@param[in] profile The TimeProfile used for profiling the dslash
*/
void ApplyCovDev(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const GaugeField &U,
int mu, int parity, bool dagger, const int *comm_override, TimeProfile &profile);
int mu, int parity, bool dagger, bool shift, const int *comm_override, TimeProfile &profile);

/**
@brief Apply clover-matrix field to a color-spinor field
Expand Down
28 changes: 28 additions & 0 deletions include/gauge_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,34 @@ namespace quda
void GFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, double epsilon, QudaGaugeSmearType smear_type,
QudaWFlowStepType step_type);

/**
 * @brief Apply a gauge transformation to a gauge field.
 * Every link U_\mu(x) is multiplied from the left by the rotation g(x) at its own
 * site and from the right by g^\dagger(x+\mu) at the neighbouring site in direction \mu:
 * U'_\mu(x) = g(x)U_\mu(x)g^\dagger(x+\mu)
 * @param[in,out] out Rotated gauge field U'_\mu(x)
 * @param[in] in Gauge field U_\mu(x)
 * @param[in] rot Rotation field g(x), passed as a GaugeField
 * (NOTE(review): whether `out` may alias `in` cannot be told from the declaration - confirm
 * against the implementation)
 */
void gaugeRotate(GaugeField &out, const GaugeField &in, const GaugeField &rot);

/**
 * @brief Gauge fixing with over-relaxation: update the rotation field g(x) for a given gauge field.
 * The gauge field `u` is taken as const; only the rotation field `rot` is modified.
 * @param[in,out] rot Rotation field to fix the gauge, updated in place
 * @param[in] u Gauge field (not modified)
 * @param[in] omega The over-relaxation parameter, most common value is 1.5 or 1.7
 * @param[in] dir_ignore The direction excluded from the gauge condition, 3 (Coulomb gauge) and 4 (Landau
 * gauge) are common choices. NOTE(review): presumably 4 works as "ignore nothing" because only
 * directions 0..3 exist - confirm against the kernel.
 */
void gaugeFixOVRStep(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore);

/**
 * @brief Compute the gauge fixing quality, functional or theta is considered as the criterion.
 * Parameters are listed in the order of the function signature.
 * @param[in,out] quality The functional and theta value (two-element array).
 * NOTE(review): presumably quality[0] is the functional and quality[1] is theta - confirm
 * against the implementation.
 * @param[in] rot Rotation field
 * @param[in] u Gauge field. NOTE(review): this was previously described as the "fixed gauge field",
 * but since the rotation field is passed separately it is presumably the unrotated field that
 * gets rotated by `rot` on the fly - confirm against the implementation.
 * @param[in] dir_ignore The ignored direction, 3 (Coulomb gauge) and 4 (Landau gauge) are common choices
 * @param[in] compute_theta Set to true to compute the theta value as the criterion
 */
void gaugeFixQuality(double quality[2], const GaugeField &rot, const GaugeField &u, int dir_ignore, bool compute_theta);

/**
* @brief Gauge fixing with overrelaxation with support for single and multi GPU.
* @param[in,out] data, quda gauge field
Expand Down
38 changes: 29 additions & 9 deletions include/kernels/covariant_derivative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace quda
/**
@brief Parameter structure for driving the covariant derivative operator
*/
template <typename Float, int nSpin_, int nColor_, typename DDArg, QudaReconstructType reconstruct_, int nDim>
template <typename Float, int nSpin_, int nColor_, typename DDArg, QudaReconstructType reconstruct_, int nDim, bool shift_>
struct CovDevArg : DslashArg<Float, nDim, DDArg> {
static constexpr int nColor = nColor_;
static constexpr int nSpin = nSpin_;
Expand All @@ -29,6 +29,8 @@ namespace quda
static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD;
typedef typename gauge_mapper<Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost>::type G;

static constexpr bool shift = shift_;

typedef typename mapper<Float>::type real;

F out[MAX_MULTI_RHS]; /** output vector field */
Expand Down Expand Up @@ -82,25 +84,33 @@ namespace quda
const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc);
const bool ghost = (coord[d] + 1 >= arg.dc.X[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

const Link U = arg.U(d, coord.x_cb, parity);

if (doHalo<kernel_type>(d) && ghost) {

const int ghost_idx = ghostFaceIndex<1>(coord, arg.dc.X, d, arg.nFace);
const Vector in = arg.halo.Ghost(d, 1, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity);

out += U * in;
if constexpr (Arg::shift) {
out += in;
} else {
const Link U = arg.U(d, coord.x_cb, parity);
out += U * in;
}
} else if (doBulk<kernel_type>() && !ghost) {

const Vector in = arg.in[src_idx](fwd_idx, their_spinor_parity);
out += U * in;

if constexpr (Arg::shift) {
out += in;
} else {
const Link U = arg.U(d, coord.x_cb, parity);
out += U * in;
}
}

} else if (mu >= 4 && arg.dd_in.doHopping(coord, d, -1)) {
// Backward gather - compute back offset for spinor and gauge fetch

const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc);
const int gauge_idx = back_idx;

const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

Expand All @@ -110,13 +120,23 @@ namespace quda
const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
const Vector in = arg.halo.Ghost(d, 0, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity);

out += conj(U) * in;
if constexpr (Arg::shift) {
out += in;
} else {
const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
out += conj(U) * in;
}
} else if (doBulk<kernel_type>() && !ghost) {

const Link U = arg.U(d, gauge_idx, 1 - parity);
const Vector in = arg.in[src_idx](back_idx, their_spinor_parity);

out += conj(U) * in;
if constexpr (Arg::shift) {
out += in;
} else {
const int gauge_idx = back_idx;
const Link U = Arg::shift ? Link() : arg.U(d, gauge_idx, 1 - parity);
out += conj(U) * in;
}
}
} // Forward/backward derivative
}
Expand Down
10 changes: 5 additions & 5 deletions include/kernels/dslash_wilson.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace quda
const G U; /** the gauge field */
const real a; /** xpay scale factor - can be -kappa or -kappa^2 */
/** parameters for distance preconditioning */
const real alpha0;
const double alpha0;
const int t0;

WilsonArg(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &halo,
Expand Down Expand Up @@ -88,10 +88,10 @@ namespace quda

const int t = coord.gx[3];
const int nt = arg.globalDim3;
real fwd_coeff_3
= Arg::distance_pc ? distanceWeight(arg, t + 1, nt) / distanceWeight(arg, t, nt) : static_cast<real>(1.0);
real bwd_coeff_3
= Arg::distance_pc ? distanceWeight(arg, t - 1, nt) / distanceWeight(arg, t, nt) : static_cast<real>(1.0);
real fwd_coeff_3 = Arg::distance_pc ? static_cast<real>(distanceWeight(arg, t + 1, nt) / distanceWeight(arg, t, nt)) :
static_cast<real>(1.0);
real bwd_coeff_3 = Arg::distance_pc ? static_cast<real>(distanceWeight(arg, t - 1, nt) / distanceWeight(arg, t, nt)) :
static_cast<real>(1.0);

#pragma unroll
for (int d = 0; d < 4; d++) { // loop over dimension - 4 and not nDim since this is used for DWF as well
Expand Down
186 changes: 186 additions & 0 deletions include/kernels/gauge_fix_ovr2.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#include <gauge_field_order.h>
#include <index_helper.cuh>
#include <quda_matrix.h>
#include <kernel.h>
#include <kernels/gauge_utils.cuh>

namespace quda
{

/**
 * @brief Kernel argument for the over-relaxation gauge fixing kernel.
 *
 * Bundles the accessor of the rotation field (written by the kernel), the accessor of the
 * gauge field (read only), the local lattice geometry and the algorithm parameters.
 *
 * @tparam Float_ Real type used by the accessors and for omega / tolerance
 * @tparam nColor_ Number of colors (only 3 is supported)
 * @tparam recon_ Reconstruction type passed to gauge_mapper
 * @tparam parity_ Parity handled by the kernel instance
 * @tparam over_relaxation_ Whether the over-relaxation branch of the update is compiled in
 */
template <typename Float_, int nColor_, QudaReconstructType recon_, int parity_, bool over_relaxation_>
struct FixGaugeArg : kernel_param<> {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  static constexpr QudaReconstructType recon = recon_;
  static constexpr int parity = parity_;
  static constexpr bool over_relaxation = over_relaxation_;
  using Gauge = typename gauge_mapper<Float, recon>::type;

  Gauge rot;     // rotation field accessor (updated by the kernel)
  const Gauge u; // gauge field accessor (read only)

  int X[4];      // grid dimensions with the border (u.R()) removed on both sides
  int border[4]; // border width per dimension, taken from u.R()
  const Float omega;     // over-relaxation parameter
  const int dir_ignore;  // direction excluded from the sum over links
  const Float tolerance; // threshold below which norms / sines are treated as zero

  /**
   * @param[in,out] rot Rotation field
   * @param[in] u Gauge field
   * @param[in] omega Over-relaxation parameter (converted to Float)
   * @param[in] dir_ignore Direction excluded from the gauge condition
   */
  FixGaugeArg(GaugeField &rot, const GaugeField &u, double omega, int dir_ignore) :
    kernel_param(dim3(u.LocalVolumeCB())),
    rot(rot),
    u(u),
    omega(omega),
    dir_ignore(dir_ignore),
    tolerance(u.toleranceSU3())
  {
    // `u` below is the constructor parameter (the GaugeField), not the accessor member
    for (int d = 0; d < 4; d++) {
      const int halo = u.R()[d];
      border[d] = halo;
      X[d] = u.X()[d] - 2 * halo;
    }
  }
};

/**
 * @brief Maximize the real trace of UW in one SU(2) subgroup and apply
 * the result to U. Note it's equivalent to a closest unitary matrix
 * problem in SU(2) subgroups.
 * U' = \argmax_{U}\mathfrak{Re}\mathrm{Tr}(UV) = \argmin_{U}||U-V^\dagger||_F^2
 * Specifically, for GL(2,C) matrices, we have
 * U = \frac{V^\dagger}{\sqrt{\det(V^\dagger)}}
 *
 * The 2x2 sub-block (i1, i2) of V = U W is projected onto a quaternion
 * (versors[0..3]), normalized and conjugated. If over-relaxation is enabled the
 * resulting unit quaternion (cos(a), n sin(a)) is replaced by (cos(omega a), n sin(omega a)).
 * The quaternion is then embedded in a 3x3 identity and multiplied onto U from the left.
 *
 * @tparam su2_index SU(2) subgroup: 0 -> indices (0,1), 1 -> indices (1,2), 2 -> indices (0,2)
 * @tparam Float Real type of the matrix elements
 * @tparam Arg Kernel argument type providing `over_relaxation`, `omega` and `tolerance`
 * @param[in,out] U Matrix updated in place
 * @param[in] W Matrix entering the trace (not modified)
 * @param[in] arg Kernel argument
 */
template <int su2_index, typename Float, typename Arg>
__host__ __device__ inline void argmaxReTrUW(Matrix<complex<Float>, 3> &U, const Matrix<complex<Float>, 3> &W,
                                             const Arg &arg)
{
  // The subgroup is a compile-time choice: reject invalid indices instead of
  // silently running with uninitialized row/column indices.
  static_assert(su2_index >= 0 && su2_index < 3, "su2_index must be 0, 1 or 2");
  constexpr int i1 = (su2_index == 1) ? 1 : 0;
  constexpr int i2 = (su2_index == 0) ? 1 : 2;

  Matrix<complex<Float>, 3> V = U * W;
  double versors[4]; // use double to avoid precision issues

  versors[0] = static_cast<double>(V(i1, i1).real() + V(i2, i2).real());
  versors[1] = static_cast<double>(V(i1, i1).imag() - V(i2, i2).imag());
  versors[2] = static_cast<double>(V(i1, i2).real() - V(i2, i1).real());
  versors[3] = static_cast<double>(V(i1, i2).imag() + V(i2, i1).imag());

  double norm
    = sqrt(versors[0] * versors[0] + versors[1] * versors[1] + versors[2] * versors[2] + versors[3] * versors[3]);
  if (norm > arg.tolerance) {
    // normalize and conjugate (flip the sign of the vector part)
    double inv_norm = 1.0 / norm;
    versors[0] *= inv_norm;
#pragma unroll
    for (int i = 1; i < 4; ++i) { versors[i] *= -inv_norm; }
  } else {
    // projection vanishes: fall back to the identity
    versors[0] = 1.0;
#pragma unroll
    for (int i = 1; i < 4; ++i) { versors[i] = 0.0; }
  }

  if constexpr (Arg::over_relaxation) {
    // Rounding may push the normalized scalar part marginally outside [-1, 1];
    // clamp it so that acos never returns NaN. Values inside the interval are untouched.
    if (versors[0] > 1.0) {
      versors[0] = 1.0;
    } else if (versors[0] < -1.0) {
      versors[0] = -1.0;
    }
    double angle = acos(versors[0]);
    double sin_angle = sin(angle);
    if (sin_angle > arg.tolerance) {
      double sin_omega_angle, cos_omega_angle;
      sincos(arg.omega * angle, &sin_omega_angle, &cos_omega_angle);
      double coeff = sin_omega_angle / sin_angle;
      versors[0] = cos_omega_angle;
#pragma unroll
      for (int i = 1; i < 4; ++i) { versors[i] *= coeff; }
    } else {
      // The rotation axis is undefined here (angle ~ 0 or angle ~ pi). Using
      // cos(omega * angle) with a zero vector part would not be a unit quaternion
      // for angle ~ pi, so fall back to the exact +/-1 element, which is unitary.
      versors[0] = versors[0] < 0.0 ? -1.0 : 1.0;
#pragma unroll
      for (int i = 1; i < 4; ++i) { versors[i] = 0.0; }
    }
  }

  // embed the SU(2) element into a 3x3 identity
  setIdentity(&V);
  V(i1, i1) = complex<Float>(static_cast<Float>(versors[0]), static_cast<Float>(versors[1]));
  V(i2, i2) = complex<Float>(static_cast<Float>(versors[0]), static_cast<Float>(-versors[1]));
  V(i1, i2) = complex<Float>(static_cast<Float>(versors[2]), static_cast<Float>(versors[3]));
  V(i2, i1) = complex<Float>(static_cast<Float>(-versors[2]), static_cast<Float>(versors[3]));

  U = V * U;
}

// /**
// * @brief Solving the closest unitary matrix in SU(2) subgroups.
// * This is another way to project the input matrix on the SU(3) group.
// */
// template <typename Float, typename Arg>
// __host__ __device__ inline void closestSu3(Matrix<complex<Float>,3> &in, Float tol)
// {
// Matrix<complex<Float>, 3> out;
// setIdentity(&out);

// constexpr int max_iter = 100;
// int i = 0;
// Float old_retr, retr = getTrace(in).real();
// do { // iterate until matrix is unitary
// // loop over SU(2) subgroup indices
// argmaxReTrUW<0, Float>(out, in);
// argmaxReTrUW<1, Float>(out, in);
// argmaxReTrUW<2, Float>(out, in);
// old_retr = retr;
// retr = getTrace(out * in).real();
// } while (abs(retr - old_retr) / old_retr > tol && ++i < max_iter);

// in = out;
// }

/**
 * @brief Functor updating the rotation field g(x) on the sites of one parity.
 *
 * For the site x it accumulates
 *   K(x) = \sum_{\mu \ne dir_ignore} [ U_\mu(x) g^\dagger(x+\mu) + (g(x-\mu) U_\mu(x-\mu))^\dagger ]
 * and then updates g(x) by three successive SU(2) subgroup steps of argmaxReTrUW on (g, K).
 * The neighbouring g are read from the opposite parity; only g(x) of parity Arg::parity is written.
 */
template <typename Arg> struct FixGauge {
  const Arg &arg;
  constexpr FixGauge(const Arg &arg) : arg(arg) { }
  static constexpr const char *filename() { return KERNEL_FILE; }

  /**
   * @param[in] x_cb Checkerboard site index within the local (border-free) volume
   */
  __device__ __host__ inline void operator()(int x_cb)
  {
    using real = typename Arg::Float;
    using Link = Matrix<complex<real>, Arg::nColor>;
    constexpr int parity = Arg::parity;
    constexpr int other_parity = 1 - parity;

    // coordinates are first computed on the border-free grid ...
    int site[4], dims[4];
#pragma unroll
    for (int d = 0; d < 4; ++d) { dims[d] = arg.X[d]; }
    getCoords(site, x_cb, dims, parity);
    // ... and then shifted into the full grid that includes the border
#pragma unroll
    for (int d = 0; d < 4; ++d) {
      site[d] += arg.border[d];
      dims[d] += 2 * arg.border[d];
    }

    const int here = linkIndex(site, dims);
    Link g = arg.rot(0, here, parity);

    // K is accumulated starting from a default-constructed Link, exactly as before
    // (assumed to be the zero matrix)
    Link K;
#pragma unroll
    for (int dir = 0; dir < 4; ++dir) {
      if (dir == arg.dir_ignore) continue;

      // forward contribution: U_dir(x) g^dagger(x + dir)
      const Link u_fwd = arg.u(dir, here, parity);
      const Link g_fwd = arg.rot(0, linkIndexP1(site, dims, dir), other_parity);
      K += u_fwd * conj(g_fwd);

      // backward contribution: (g(x - dir) U_dir(x - dir))^dagger
      const int behind = linkIndexM1(site, dims, dir);
      const Link u_bwd = arg.u(dir, behind, other_parity);
      const Link g_bwd = arg.rot(0, behind, other_parity);
      K += conj(g_bwd * u_bwd);
    }

    // loop over SU(2) subgroup indices
    argmaxReTrUW<0, real>(g, K, arg);
    argmaxReTrUW<1, real>(g, K, arg);
    argmaxReTrUW<2, real>(g, K, arg);

    arg.rot(0, here, parity) = g;
  }
};
} // namespace quda
Loading