11#include < cassert>
22#include < cstring>
3- #include < stdexcept>
43#include " nn-executor.hpp"
54
65void NnFakeNodeSynchronizer::sync (NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
@@ -39,6 +38,10 @@ NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmen
3938 this ->segmentTo = segmentTo;
4039}
4140
// Constructs an NnExecutorException by forwarding `message` to the
// std::runtime_error base class; the constructor body itself does nothing else.
// NOTE(review): `message` is taken by value; presumably this mirrors the
// declaration in nn-executor.hpp — confirm before changing the parameter type.
41+ NnExecutorException::NnExecutorException (const std::string message)
42+ : std::runtime_error(message)
43+ {}
44+
4245NnExecutor::NnExecutor (NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
4346 : segments(nodeConfig->nSegments), steps()
4447{
@@ -137,13 +140,19 @@ static inline void *executorThreadHandler(void *arg) {
137140 NnUint nThreads = context->nThreads ;
138141 NnUint doneCount = nThreads - 1 ;
139142
140- while (true ) {
143+ while (context-> isAlive . load () ) {
141144 const unsigned int currentStepIndex = context->currentStepIndex .load ();
142145 if (currentStepIndex == context->nSteps )
143146 break ;
144147
145148 NnExecutorStep *step = &context->steps [currentStepIndex];
146- executeStep (step, nThreads, thread, context);
149+ try {
150+ executeStep (step, nThreads, thread, context);
151+ } catch (const std::runtime_error &e) {
152+ context->isAlive .store (false );
153+ printf (" 🚨 Execution error: %s\n " , e.what ());
154+ break ;
155+ }
147156
148157 NnUint currentCount = context->doneThreadCount .fetch_add (1 );
149158 if (currentCount == doneCount) {
@@ -156,7 +165,10 @@ static inline void *executorThreadHandler(void *arg) {
156165 context->doneThreadCount .store (0 );
157166 context->currentStepIndex .fetch_add (1 );
158167 } else {
159- while (context->currentStepIndex .load () == currentStepIndex);
168+ while (
169+ context->currentStepIndex .load () == currentStepIndex &&
170+ context->isAlive .load ()
171+ );
160172 }
161173 }
162174 return nullptr ;
@@ -166,6 +178,7 @@ void NnExecutor::forward() {
166178 assert (netExecution->batchSize > 0 );
167179
168180 NnUint nThreads = netExecution->nThreads ;
181+ context.isAlive .exchange (true );
169182 context.currentStepIndex .exchange (0 );
170183 context.doneThreadCount .exchange (0 );
171184 context.batchSize = netExecution->batchSize ;
@@ -178,12 +191,14 @@ void NnExecutor::forward() {
178191 NnUint threadIndex;
179192 for (threadIndex = 1 ; threadIndex < nThreads; threadIndex++) {
180193 int result = pthread_create (&threads[threadIndex].handler , NULL , (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]);
181- if (result != 0 )
182- throw std::runtime_error (" Failed to create thread" );
194+ assert (result == 0 && " Failed to create thread" );
183195 }
184196 executorThreadHandler ((void *)&threads[0 ]);
185197 for (threadIndex = 1 ; threadIndex < nThreads; threadIndex++)
186198 pthread_join (threads[threadIndex].handler , NULL );
199+
200+ if (!context.isAlive .load ())
201+ throw NnExecutorException (" Execution failed in one of the threads" );
187202}
188203
189204NnUint NnExecutor::getTotalTime (NnExecutorStepType type) {
0 commit comments