Skip to content

Commit 3f37301

Browse files
committed
@TensorFlowCns::setupSynapses; debug reduced
Purpose: since `session->Run` can have order-of-ops issues, try not to reuse training step to compute `loss`. Is followup to commit HEAD~1 (@`TensorFlowCns::setupSynapses`; debug print loss). TODO: `squash` this
1 parent 481095d commit 3f37301

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

cxx/ClassTensorFlowCns.hxx

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,18 +355,16 @@ public:
355355
{"input", inputTensor},
356356
{"labels", expectedOutputTensor}
357357
},
358-
outputTensors,
358+
(0 == validationCount) ? outputTensors : std::vector<std::string>(),
359359
{"optimizerCoefficients" SUSUWU_CNS_IF_BIAS(SUSUWU_COMMA "optimizerBiases")},
360-
&outputs
360+
(0 == validationCount) ? &outputs : SUSUWU_NULLPTR
361361
);
362362
if(!status.ok()) {
363363
throw std::runtime_error(SUSUWU_ERRSTR(SUSUWU_SH_ERROR, getName() + "::setupSynapses() { const tensorflow::Status status = session->Run({{\"input\", inputTensor}, {\"labels\", expectedOutputTensor}, {" + ((0 == validationCount) ? "\"loss\"" : "") + "}, {\"optimizerCoefficients\" " + "" SUSUWU_CNS_IF_BIAS(", \"optimizerBiases\"") + "}, &outputs); (!status.ok()) { epoch == " + std::to_string(epoch) + "; status.ToString() == \"" + status.ToString() + "\"; } }"));
364364
}
365-
const float trainingLossVal = outputs[0].scalar<float>()();
366-
SUSUWU_DEBUG(getName() + "::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { /* training */ epoch == " + std::to_string(epoch) + "; trainingLossVal == " + std::to_string(trainingLossVal) + "; status.ToString() == \"" + status.ToString() + "\"; ... } }");
367365
float lossVal;
368366
if(0 == validationCount) {
369-
lossVal = trainingLossVal;
367+
lossVal = outputs[0].scalar<float>()();
370368
} else {
371369
outputs.clear();
372370
status = session->Run(
@@ -378,9 +376,9 @@ public:
378376
if(!status.ok()) {
379377
throw std::runtime_error(SUSUWU_ERRSTR(SUSUWU_SH_ERROR, getName() + "::setupSynapses() { const tensorflow::Status status = session->Run({{\"inputs\", inputTensor2}, {\"labels\", expectedOutputTensor2}}, {\"loss\"}, {}, &outputs); (!status.ok()) { epoch == " + std::to_string(epoch) + "; status.ToString() == \"" + status.ToString() + "\"; } }"));
380378
}
381-
lossVal = outputs[0].scalar<float>()(); /* TODO: use for eager stop */
382-
SUSUWU_DEBUG(getName() + "::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { ... /* validation */ epoch == " + std::to_string(epoch) + "; lossVal == " + std::to_string(trainingLossVal) + "; status.ToString() == \"" + status.ToString() + "\"; ... } }");
379+
lossVal = outputs[0].scalar<float>()();
383380
}
381+
SUSUWU_DEBUG(getName() + "::setupSynapses() { for(size_t epoch = 0; epoch < trainingIterations; ++epoch) { epoch == " + std::to_string(epoch) + "; lossVal == " + std::to_string(lossVal) + "; status.ToString() == \"" + status.ToString() + "\"; ... } }");
384382
if(lossVal < desiredLossThreshold) { break; }
385383
if(lossVal < bestLoss - minLossDelta) {
386384
bestLoss = lossVal;

0 commit comments

Comments
 (0)