Fix: use gemm instead of einsum in BPCG #5827
|
There are still some strange issues when running `OMP_NUM_THREADS=1 mpirun -np 4 abacus`. It oscillates around the wrong result and does not converge within 100 iteration steps:
START CHARGE : atomic
DONE(0.378312 SEC) : INIT SCF
ITER ETOT/eV EDIFF/eV DRHO TIME/s
BP1 -3.20615851e+03 0.00000000e+00 6.2402e+01 0.95
BP2 -3.80285738e+03 -5.96698874e+02 2.7990e+01 0.17
BP3 -4.10701176e+03 -3.04154385e+02 6.5602e+00 0.18
BP4 -4.04583492e+03 6.11768470e+01 1.1033e+00 0.17
BP5 -4.04704636e+03 -1.21144075e+00 9.9010e-01 0.18
BP6 -4.03059998e+03 1.64463831e+01 9.1813e-01 0.17
BP7 -3.99186247e+03 3.87375006e+01 9.7185e-01 0.17
BP8 -3.96744993e+03 2.44125466e+01 1.1061e+00 0.18
BP9 -3.97584003e+03 -8.39010274e+00 9.6145e-01 0.18
BP10 -3.94152708e+03 3.43129544e+01 9.6275e-01 0.17
BP11 -3.93178473e+03 9.74234358e+00 7.6134e-01 0.18
When running on a single core, it gives the right result in 9 steps:
START CHARGE : atomic
DONE(0.442483 SEC) : INIT SCF
ITER ETOT/eV EDIFF/eV DRHO TIME/s
BP1 -4.87034216e+03 0.00000000e+00 1.5454e+00 37.38
BP2 -4.86916145e+03 1.18070831e+00 3.9344e-01 7.28
BP3 -4.86973378e+03 -5.72328822e-01 1.3878e-02 6.04
BP4 -4.86974314e+03 -9.36599264e-03 1.2855e-03 5.89
BP5 -4.86974610e+03 -2.95286226e-03 1.8366e-04 5.72
BP6 -4.86974692e+03 -8.21680864e-04 2.8456e-05 5.53
BP7 -4.86974704e+03 -1.19617951e-04 2.7626e-06 5.93
BP8 -4.86974705e+03 -1.34459124e-05 2.3207e-07 6.37
BP9 -4.86974705e+03 -9.02173064e-07 2.4438e-08 6.34 |
|
When replacing 4 of the 5 einsum calls with gemm, the output is:
START CHARGE : atomic
DONE(0.434527 SEC) : INIT SCF
ITER ETOT/eV EDIFF/eV DRHO TIME/s
BP1 -4.86923965e+03 0.00000000e+00 1.5121e+00 35.52
BP2 -4.86283925e+03 6.40040239e+00 4.9555e-01 5.97
BP3 -4.86611767e+03 -3.27841978e+00 1.2539e-02 5.88
BP4 -4.86574893e+03 3.68737277e-01 6.6463e-03 6.07
BP5 -4.86570499e+03 4.39385803e-02 3.7749e-03 7.38
BP6 -4.86494756e+03 7.57431141e-01 3.1324e-03 6.53
BP7 -4.86623319e+03 -1.28563301e+00 1.5995e-03 7.43
BP8 -4.86433126e+03 1.90193658e+00 5.7517e-03 6.98
BP9 -4.86553200e+03 -1.20073878e+00 1.0246e-03 7.06
BP10 -4.86593390e+03 -4.01906932e-01 1.6900e-03 5.75
I will go and investigate. |
|
It turns out that the 4th gemm should take
// workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
gemm_op()(this->ctx,
'N',
'N',
this->n_basis, //m
this->n_band, //n
this->n_band, //k
this->one, //1.0
psi_out.data<T>(),
this->n_basis, //lda
hsub_in.data<T>(),
this->n_band, //ldb
this->zero, //0.0
workspace_in.data<T>(),
this->n_basis); //ldc |
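To make the argument roles in the call above concrete, here is a minimal reference sketch of a column-major gemm with no transposition. `naive_gemm_nn` and `rotate` are hypothetical names introduced for illustration and are not ABACUS functions; the sketch only assumes the usual BLAS convention that element (i, j) of a matrix with leading dimension `ld` is stored at index `i + j * ld`.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cplx = std::complex<double>;

// Reference column-major GEMM, no transposition (hypothetical helper):
//   C(m x n) = alpha * A(m x k) * B(k x n) + beta * C(m x n)
void naive_gemm_nn(int m, int n, int k, cplx alpha,
                   const cplx* a, int lda,
                   const cplx* b, int ldb,
                   cplx beta, cplx* c, int ldc)
{
    // A leading dimension may exceed the row count, but never be smaller.
    assert(lda >= m && ldb >= k && ldc >= m);
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            cplx acc(0.0, 0.0);
            for (int p = 0; p < k; ++p) {
                acc += a[i + p * lda] * b[p + j * ldb];
            }
            c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
        }
    }
}

// Mirrors the shape of the call discussed above (hypothetical helper):
//   workspace(n_basis x n_band) = psi(n_basis x n_band) * hsub(n_band x n_band)
std::vector<cplx> rotate(const std::vector<cplx>& psi,
                         const std::vector<cplx>& hsub,
                         int n_basis, int n_band)
{
    std::vector<cplx> workspace(static_cast<std::size_t>(n_basis) * n_band,
                                cplx(0.0, 0.0));
    naive_gemm_nn(n_basis, n_band, n_band, cplx(1.0, 0.0),
                  psi.data(), n_basis,   // lda
                  hsub.data(), n_band,   // ldb
                  cplx(0.0, 0.0),
                  workspace.data(), n_basis);  // ldc
    return workspace;
}
```

With a 2 x 2 `hsub` that swaps the two bands, `rotate` returns `psi` with its two columns exchanged, which is a quick way to check that the operand order and leading dimensions are consistent.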
|
|
|
There may be some accuracy difference when migrating from einsum to gemm. |
|
Inner implementation of Tensor einsum:
// einsum_op.cpp
// Call the column-major Blas library
kernels::blas_gemm<T, Device>()(
option.conj_y ? 'C' : trans_y ? 'T' : 'N',
option.conj_x ? 'C' : trans_x ? 'T' : 'N',
n, m, k,
&alpha,
y_device_memory_ptrs[0], option.conj_y || trans_y ? k : n,
x_device_memory_ptrs[0], option.conj_x || trans_x ? m : k,
&beta,
z_device_memory_ptrs[0], n);
However,
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
It actually performs this calculation: |
|
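The operand swap in the quoted `blas_gemm` call can be checked with a small sketch. A row-major buffer read as column-major is the transpose of the original matrix, so a row-major `z = x * y^H` becomes the column-major product `z^T = conj(y) * x^T`. The code below assumes that the `"ij,kj->ik"` contraction with `conj_y = true` selects `'C'` for the `y` operand, as the quoted snippet suggests; `einsum_ij_kj_conj`, `naive_gemm`, and `einsum_via_gemm` are hypothetical names, not the library's API.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cplx = std::complex<double>;

// Direct row-major evaluation of "ij,kj->ik" with the second operand conjugated:
//   z(i, c) = sum_j x(i, j) * conj(y(c, j)),  x: m x kdim, y: n x kdim, z: m x n
std::vector<cplx> einsum_ij_kj_conj(const std::vector<cplx>& x,
                                    const std::vector<cplx>& y,
                                    int m, int n, int kdim)
{
    std::vector<cplx> z(static_cast<std::size_t>(m) * n);
    for (int i = 0; i < m; ++i) {
        for (int c = 0; c < n; ++c) {
            cplx acc(0.0, 0.0);
            for (int j = 0; j < kdim; ++j) {
                acc += x[i * kdim + j] * std::conj(y[c * kdim + j]);
            }
            z[i * n + c] = acc;
        }
    }
    return z;
}

// Column-major GEMM restricted to op(A) in {'N', 'C'} and op(B) = 'N':
//   C(m x n) = op(A)(m x k) * B(k x n)
void naive_gemm(char trans_a, int m, int n, int k,
                const cplx* a, int lda, const cplx* b, int ldb,
                cplx* c, int ldc)
{
    assert(trans_a == 'N' || trans_a == 'C');
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            cplx acc(0.0, 0.0);
            for (int p = 0; p < k; ++p) {
                const cplx a_ip = (trans_a == 'N') ? a[i + p * lda]
                                                   : std::conj(a[p + i * lda]);
                acc += a_ip * b[p + j * ldb];
            }
            c[i + j * ldc] = acc;
        }
    }
}

// Same contraction through the column-major routine: y comes first with 'C',
// x second with 'N', and the dimensions are passed as (n, m, kdim).
std::vector<cplx> einsum_via_gemm(const std::vector<cplx>& x,
                                  const std::vector<cplx>& y,
                                  int m, int n, int kdim)
{
    std::vector<cplx> z(static_cast<std::size_t>(m) * n);
    naive_gemm('C', n, m, kdim, y.data(), kdim, x.data(), kdim, z.data(), n);
    return z;
}
```

Both functions fill the same row-major `z`, so comparing their outputs on a small non-square case is a cheap check when replacing an einsum call with an explicit gemm call.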
The result difference may come from doing
If this is true, we need to update the reference data file to match the corrected version of the BPCG code. |
This reverts commit a3350b7.
|
|
|
New ref will be updated according to results as follows: |
* Add dimension parameter for BPCG method
* Add utils for hsolver gemm_op
* Change code to fit new bpcg init interface
* using gemm instead of einsum in orth_cholesky
* using gemm instead of einsum in orth_projection
* replace einsum by gemm in orth_projection
* replace einsum by gemm in rotate_wf
* replace einsum by gemm in diag_hsub
* Update 102_PW_BPCG totalstressref reference value


Linked Issue
try to fix #3437
What's changed?
Use psi.get_current_ngk() to address leading dimension vs. valid dimension.
Unit Tests and/or Case Tests for my changes
tests/integrate/102_PW_BPCG: result.ref needs to be adjusted to the new code output!
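
The leading-dimension versus valid-dimension distinction can be illustrated with a small sketch. It assumes, as a simplification, that each band is stored as a column padded to a leading dimension `ld`, while only the first `n_rows` entries are meaningful (the role played here by the value from `psi.get_current_ngk()`). `overlap` is a hypothetical helper, not part of ABACUS.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

using cplx = std::complex<double>;

// Overlap matrix over the first n_rows entries of each column (hypothetical helper):
//   s(p, q) = sum_{i < n_rows} conj(a(i, p)) * a(i, q)
// a is column-major with leading dimension ld >= n_rows; rows n_rows..ld-1
// are padding and must not contribute to the sum.
std::vector<cplx> overlap(const std::vector<cplx>& a,
                          int n_rows, int n_band, int ld)
{
    assert(ld >= n_rows);
    assert(a.size() >= static_cast<std::size_t>(ld) * n_band);
    std::vector<cplx> s(static_cast<std::size_t>(n_band) * n_band);
    for (int q = 0; q < n_band; ++q) {
        for (int p = 0; p < n_band; ++p) {
            cplx acc(0.0, 0.0);
            for (int i = 0; i < n_rows; ++i) {   // valid dimension bounds the sum
                acc += std::conj(a[i + p * ld])  // leading dimension sets the stride
                       * a[i + q * ld];
            }
            s[p + q * n_band] = acc;
        }
    }
    return s;
}
```

Passing `ld` as the row count instead of the valid count would pull the padding entries into every sum, which is one way a gemm-based rewrite can silently change results.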