@@ -512,8 +512,8 @@ run_model (ENERGYTYPE & dener,
512512 return ;
513513 }
514514
515+ #ifdef USE_CUDA_TOOLKIT
515516 std::vector<Tensor> output_tensors;
516-
517517 checkStatus (session->Run (input_tensors,
518518 {" o_energy" , " o_force" , " o_atom_virial" },
519519 {},
@@ -548,8 +548,37 @@ run_model (ENERGYTYPE & dener,
548548 dvirial[7 ] += 1.0 * datom_virial[9 *ii+7 ];
549549 dvirial[8 ] += 1.0 * datom_virial[9 *ii+8 ];
550550 }
551+
551552 dforce_ = dforce;
552553 nnpmap.backward (dforce_.begin (), dforce.begin (), 3 );
554+ #else
555+ std::vector<Tensor> output_tensors;
556+
557+ checkStatus (session->Run (input_tensors,
558+ {" o_energy" , " o_force" , " o_virial" },
559+ {},
560+ &output_tensors));
561+
562+ Tensor output_e = output_tensors[0 ];
563+ Tensor output_f = output_tensors[1 ];
564+ Tensor output_v = output_tensors[2 ];
565+
566+ auto oe = output_e.flat <ENERGYTYPE> ();
567+ auto of = output_f.flat <VALUETYPE> ();
568+ auto ov = output_v.flat <VALUETYPE> ();
569+
570+ dener = oe (0 );
571+ vector<VALUETYPE> dforce (3 * nall);
572+ dvirial.resize (9 );
573+ for (unsigned ii = 0 ; ii < nall * 3 ; ++ii){
574+ dforce[ii] = of (ii);
575+ }
576+ for (unsigned ii = 0 ; ii < 9 ; ++ii){
577+ dvirial[ii] = ov (ii);
578+ }
579+ dforce_ = dforce;
580+ nnpmap.backward (dforce_.begin (), dforce.begin (), 3 );
581+ #endif
553582}
554583
555584static void run_model (ENERGYTYPE & dener,
@@ -581,7 +610,7 @@ static void run_model (ENERGYTYPE & dener,
581610 fill (datom_virial_.begin (), datom_virial_.end (), 0.0 );
582611 return ;
583612 }
584-
613+ # ifdef USE_CUDA_TOOLKIT
585614 std::vector<Tensor> output_tensors;
586615
587616 checkStatus (session->Run (input_tensors,
@@ -630,6 +659,50 @@ static void run_model (ENERGYTYPE & dener,
630659 nnpmap.backward (dforce_.begin (), dforce.begin (), 3 );
631660 nnpmap.backward (datom_energy_.begin (), datom_energy.begin (), 1 );
632661 nnpmap.backward (datom_virial_.begin (), datom_virial.begin (), 9 );
662+ #else
663+ std::vector<Tensor> output_tensors;
664+
665+ checkStatus (session->Run (input_tensors,
666+ {" o_energy" , " o_force" , " o_virial" , " o_atom_energy" , " o_atom_virial" },
667+ {},
668+ &output_tensors));
669+
670+ Tensor output_e = output_tensors[0 ];
671+ Tensor output_f = output_tensors[1 ];
672+ Tensor output_v = output_tensors[2 ];
673+ Tensor output_ae = output_tensors[3 ];
674+ Tensor output_av = output_tensors[4 ];
675+
676+ auto oe = output_e.flat <ENERGYTYPE> ();
677+ auto of = output_f.flat <VALUETYPE> ();
678+ auto ov = output_v.flat <VALUETYPE> ();
679+ auto oae = output_ae.flat <VALUETYPE> ();
680+ auto oav = output_av.flat <VALUETYPE> ();
681+
682+ dener = oe (0 );
683+ vector<VALUETYPE> dforce (3 * nall);
684+ vector<VALUETYPE> datom_energy (nall, 0 );
685+ vector<VALUETYPE> datom_virial (9 * nall);
686+ dvirial.resize (9 );
687+ for (int ii = 0 ; ii < nall * 3 ; ++ii) {
688+ dforce[ii] = of (ii);
689+ }
690+ for (int ii = 0 ; ii < nloc; ++ii) {
691+ datom_energy[ii] = oae (ii);
692+ }
693+ for (int ii = 0 ; ii < nall * 9 ; ++ii) {
694+ datom_virial[ii] = oav (ii);
695+ }
696+ for (int ii = 0 ; ii < 9 ; ++ii) {
697+ dvirial[ii] = ov (ii);
698+ }
699+ dforce_ = dforce;
700+ datom_energy_ = datom_energy;
701+ datom_virial_ = datom_virial;
702+ nnpmap.backward (dforce_.begin (), dforce.begin (), 3 );
703+ nnpmap.backward (datom_energy_.begin (), datom_energy.begin (), 1 );
704+ nnpmap.backward (datom_virial_.begin (), datom_virial.begin (), 9 );
705+ #endif
633706}
634707
635708static void
0 commit comments