Skip to content

Commit 6b3a6c6

Browse files
LuLu
authored andcommitted
fix bugs, support none-cuda float precision
1 parent 58ab3f6 commit 6b3a6c6

File tree

1 file changed

+75
-2
lines changed

1 file changed

+75
-2
lines changed

source/lib/src/NNPInter.cc

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,8 +512,8 @@ run_model (ENERGYTYPE & dener,
512512
return;
513513
}
514514

515+
#ifdef USE_CUDA_TOOLKIT
515516
std::vector<Tensor> output_tensors;
516-
517517
checkStatus (session->Run(input_tensors,
518518
{"o_energy", "o_force", "o_atom_virial"},
519519
{},
@@ -548,8 +548,37 @@ run_model (ENERGYTYPE & dener,
548548
dvirial[7] += 1.0 * datom_virial[9*ii+7];
549549
dvirial[8] += 1.0 * datom_virial[9*ii+8];
550550
}
551+
551552
dforce_ = dforce;
552553
nnpmap.backward (dforce_.begin(), dforce.begin(), 3);
554+
#else
555+
std::vector<Tensor> output_tensors;
556+
557+
checkStatus (session->Run(input_tensors,
558+
{"o_energy", "o_force", "o_virial"},
559+
{},
560+
&output_tensors));
561+
562+
Tensor output_e = output_tensors[0];
563+
Tensor output_f = output_tensors[1];
564+
Tensor output_v = output_tensors[2];
565+
566+
auto oe = output_e.flat <ENERGYTYPE> ();
567+
auto of = output_f.flat <VALUETYPE> ();
568+
auto ov = output_v.flat <VALUETYPE> ();
569+
570+
dener = oe(0);
571+
vector<VALUETYPE> dforce (3 * nall);
572+
dvirial.resize (9);
573+
for (unsigned ii = 0; ii < nall * 3; ++ii){
574+
dforce[ii] = of(ii);
575+
}
576+
for (unsigned ii = 0; ii < 9; ++ii){
577+
dvirial[ii] = ov(ii);
578+
}
579+
dforce_ = dforce;
580+
nnpmap.backward (dforce_.begin(), dforce.begin(), 3);
581+
#endif
553582
}
554583

555584
static void run_model (ENERGYTYPE & dener,
@@ -581,7 +610,7 @@ static void run_model (ENERGYTYPE & dener,
581610
fill(datom_virial_.begin(), datom_virial_.end(), 0.0);
582611
return;
583612
}
584-
613+
#ifdef USE_CUDA_TOOLKIT
585614
std::vector<Tensor> output_tensors;
586615

587616
checkStatus (session->Run(input_tensors,
@@ -630,6 +659,50 @@ static void run_model (ENERGYTYPE & dener,
630659
nnpmap.backward (dforce_.begin(), dforce.begin(), 3);
631660
nnpmap.backward (datom_energy_.begin(), datom_energy.begin(), 1);
632661
nnpmap.backward (datom_virial_.begin(), datom_virial.begin(), 9);
662+
#else
663+
std::vector<Tensor> output_tensors;
664+
665+
checkStatus (session->Run(input_tensors,
666+
{"o_energy", "o_force", "o_virial", "o_atom_energy", "o_atom_virial"},
667+
{},
668+
&output_tensors));
669+
670+
Tensor output_e = output_tensors[0];
671+
Tensor output_f = output_tensors[1];
672+
Tensor output_v = output_tensors[2];
673+
Tensor output_ae = output_tensors[3];
674+
Tensor output_av = output_tensors[4];
675+
676+
auto oe = output_e.flat <ENERGYTYPE> ();
677+
auto of = output_f.flat <VALUETYPE> ();
678+
auto ov = output_v.flat <VALUETYPE> ();
679+
auto oae = output_ae.flat <VALUETYPE> ();
680+
auto oav = output_av.flat <VALUETYPE> ();
681+
682+
dener = oe(0);
683+
vector<VALUETYPE> dforce (3 * nall);
684+
vector<VALUETYPE> datom_energy (nall, 0);
685+
vector<VALUETYPE> datom_virial (9 * nall);
686+
dvirial.resize (9);
687+
for (int ii = 0; ii < nall * 3; ++ii) {
688+
dforce[ii] = of(ii);
689+
}
690+
for (int ii = 0; ii < nloc; ++ii) {
691+
datom_energy[ii] = oae(ii);
692+
}
693+
for (int ii = 0; ii < nall * 9; ++ii) {
694+
datom_virial[ii] = oav(ii);
695+
}
696+
for (int ii = 0; ii < 9; ++ii) {
697+
dvirial[ii] = ov(ii);
698+
}
699+
dforce_ = dforce;
700+
datom_energy_ = datom_energy;
701+
datom_virial_ = datom_virial;
702+
nnpmap.backward (dforce_.begin(), dforce.begin(), 3);
703+
nnpmap.backward (datom_energy_.begin(), datom_energy.begin(), 1);
704+
nnpmap.backward (datom_virial_.begin(), datom_virial.begin(), 9);
705+
#endif
633706
}
634707

635708
static void

0 commit comments

Comments
 (0)