@@ -355,18 +355,16 @@ public:
355355 {" input" , inputTensor},
356356 {" labels" , expectedOutputTensor}
357357 },
358- outputTensors,
358+ ( 0 == validationCount) ? outputTensors : std::vector<std::string>() ,
359359 {" optimizerCoefficients" SUSUWU_CNS_IF_BIAS (SUSUWU_COMMA " optimizerBiases" )},
360- &outputs
360+ ( 0 == validationCount) ? &outputs : SUSUWU_NULLPTR
361361 );
362362 if (!status.ok ()) {
363363 throw std::runtime_error (SUSUWU_ERRSTR (SUSUWU_SH_ERROR, getName () + " ::setupSynapses() { const tensorflow::Status status = session->Run({{\" input\" , inputTensor}, {\" labels\" , expectedOutputTensor}, {" + ((0 == validationCount) ? " \" loss\" " : " " ) + " }, {\" optimizerCoefficients\" " + " " SUSUWU_CNS_IF_BIAS (" , \" optimizerBiases\" " ) + " }, &outputs); (!status.ok()) { epoch == " + std::to_string (epoch) + " ; status.ToString() == \" " + status.ToString () + " \" ; } }" ));
364364 }
365- const float trainingLossVal = outputs[0 ].scalar <float >()();
366- SUSUWU_DEBUG (getName () + " ::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { /* training */ epoch == " + std::to_string (epoch) + " ; trainingLossVal == " + std::to_string (trainingLossVal) + " ; status.ToString() == \" " + status.ToString () + " \" ; ... } }" );
367365 float lossVal;
368366 if (0 == validationCount) {
369- lossVal = trainingLossVal ;
367+ lossVal = outputs[ 0 ]. scalar < float >()() ;
370368 } else {
371369 outputs.clear ();
372370 status = session->Run (
@@ -378,9 +376,9 @@ public:
378376 if (!status.ok ()) {
379377 throw std::runtime_error (SUSUWU_ERRSTR (SUSUWU_SH_ERROR, getName () + " ::setupSynapses() { const tensorflow::Status status = session->Run({{\" inputs\" , inputTensor2}, {\" labels\" , expectedOutputTensor2}}, {\" loss\" }, {}, &outputs); (!status.ok()) { epoch == " + std::to_string (epoch) + " ; status.ToString() == \" " + status.ToString () + " \" ; } }" ));
380378 }
381- lossVal = outputs[0 ].scalar <float >()(); /* TODO: use for eager stop */
382- SUSUWU_DEBUG (getName () + " ::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { ... /* validation */ epoch == " + std::to_string (epoch) + " ; lossVal == " + std::to_string (trainingLossVal) + " ; status.ToString() == \" " + status.ToString () + " \" ; ... } }" );
379+ lossVal = outputs[0 ].scalar <float >()();
383380 }
381+ SUSUWU_DEBUG (getName () + " ::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { epoch == " + std::to_string (epoch) + " ; lossVal == " + std::to_string (lossVal) + " ; status.ToString() == \" " + status.ToString () + " \" ; ... } }" );
384382 if (lossVal < desiredLossThreshold) { break ; }
385383 if (lossVal < bestLoss - minLossDelta) {
386384 bestLoss = lossVal;
0 commit comments