Skip to content

Commit 96c661e

Browse files
authored
feat: improved api reliability. (#268)
1 parent 0ec7340 commit 96c661e

File tree

11 files changed

+183
-111
lines changed

11 files changed

+183
-111
lines changed

src/app.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ static ChatTemplateType parseChatTemplateType(char *val) {
2323

2424
AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
2525
AppCliArgs args;
26+
args.info = true;
2627
args.help = false;
2728
args.mode = nullptr;
2829
args.nBatches = 32;
@@ -236,7 +237,7 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
236237
throw std::runtime_error("This version supports only Q40 weights with Q80 sync type");
237238

238239
Tokenizer tokenizer(args->tokenizerPath);
239-
if (tokenizer.vocabSize != header.vocabSize)
240+
if (args->info && tokenizer.vocabSize != header.vocabSize)
240241
printf("Tokenizer vocab size (%d) does not match the model vocab size (%d)\n", tokenizer.vocabSize, header.vocabSize);
241242

242243
Sampler sampler(tokenizer.vocabSize, args->temperature, args->topp, args->seed);
@@ -246,8 +247,11 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
246247

247248
NnNodeConfig *rootNodeConfig = &net.nodeConfigs[0];
248249

249-
printLlmHeader(&header);
250-
printNodeRequiredMemory(&net.netConfig, rootNodeConfig);
250+
if (args->info) {
251+
tokenizer.printHeader();
252+
printLlmHeader(&header);
253+
printNodeRequiredMemory(&net.netConfig, rootNodeConfig);
254+
}
251255

252256
NnNetExecution execution(args->nThreads, &net.netConfig);
253257

@@ -346,11 +350,11 @@ void runWorkerApp(AppCliArgs *args) {
346350
}
347351
executor.forward();
348352
isFirstAttempt = true;
349-
} catch (const NnReadNetworkException &e) {
350-
printf("Read network exception: %s\n", e.message);
353+
} catch (const NnTransferSocketException &e) {
354+
printf("🚨 Network error: %s\n", e.what());
351355
break;
352-
} catch (const NnWriteNetworkException &e) {
353-
printf("Write network exception: %s\n", e.message);
356+
} catch (const NnExecutorException &e) {
357+
printf("🚨 Inference error: %s\n", e.what());
354358
break;
355359
}
356360
}

src/app.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class AppCliArgs {
1212
char *mode;
1313
NnUint nThreads;
1414
NnUint nBatches;
15+
bool info;
1516
bool help;
1617

1718
// inference

src/dllama-api.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <vector>
1010
#include <string>
1111
#include <csignal>
12+
#include <thread>
13+
#include <chrono>
1214

1315
#ifdef _WIN32
1416
#include <winsock2.h>
@@ -532,7 +534,7 @@ void handleModelsRequest(HttpRequest& request, const char* modelPath) {
532534
}
533535

534536
static void server(AppInferenceContext *context) {
535-
int serverSocket = createServerSocket(context->args->port);
537+
NnSocket serverSocket(createServerSocket(context->args->port));
536538

537539
TokenizerChatStops stops(context->tokenizer);
538540
ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
@@ -556,19 +558,16 @@ static void server(AppInferenceContext *context) {
556558

557559
while (true) {
558560
try {
559-
int clientSocket = acceptSocket(serverSocket);
560-
HttpRequest request = HttpRequest::read(clientSocket);
561+
NnSocket clientSocket(acceptSocket(serverSocket.fd));
562+
HttpRequest request = HttpRequest::read(clientSocket.fd);
561563
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
562564
Router::resolve(request, routes);
563-
destroySocket(clientSocket);
564-
} catch (NnReadNetworkException& ex) {
565-
printf("Read socket error: %d %s\n", ex.code, ex.message);
566-
} catch (NnWriteNetworkException& ex) {
567-
printf("Write socket error: %d %s\n", ex.code, ex.message);
565+
} catch (const NnTransferSocketException& e) {
566+
printf("Socket error: %d %s\n", e.code, e.what());
567+
} catch (const NnExecutorException &e) {
568+
throw;
568569
}
569570
}
570-
571-
destroySocket(serverSocket);
572571
}
573572

574573
#ifdef _WIN32
@@ -601,22 +600,29 @@ int main(int argc, char *argv[]) {
601600
std::signal(SIGPIPE, SIG_IGN);
602601
#endif
603602

603+
AppCliArgs args = AppCliArgs::parse(argc, argv, false);
604+
if (args.help) {
605+
usage();
606+
return EXIT_SUCCESS;
607+
}
608+
604609
initQuants();
605610
initSockets();
606611

607-
int returnCode = EXIT_SUCCESS;
608-
try {
609-
AppCliArgs args = AppCliArgs::parse(argc, argv, false);
610-
if (args.help) {
611-
usage();
612-
} else {
612+
while (true) {
613+
try {
613614
runInferenceApp(&args, server);
615+
} catch (const NnConnectionSocketException &e) {
616+
printf("🚨 Connection error: %s\n", e.what());
617+
} catch (const NnExecutorException &e) {
618+
printf("🚨 Inference error: %s\n", e.what());
614619
}
615-
} catch (std::exception &e) {
616-
printf("🚨 Critical error: %s\n", e.what());
617-
returnCode = EXIT_FAILURE;
620+
621+
printf("🔄 Retrying in 3 seconds...\n");
622+
std::this_thread::sleep_for(std::chrono::seconds(3));
623+
args.info = false;
618624
}
619625

620626
cleanupSockets();
621-
return returnCode;
627+
return EXIT_SUCCESS;
622628
}

src/dllama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ int main(int argc, char **argv) {
275275
runWorkerApp(&args);
276276
else
277277
throw std::runtime_error("Unsupported mode");
278-
} catch (std::exception &e) {
278+
} catch (const std::exception &e) {
279279
printf("🚨 Critical error: %s\n", e.what());
280280
returnCode = EXIT_FAILURE;
281281
}

src/nn/nn-core.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ void releaseNetConfig(NnNetConfig *netConfig) {
145145
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) {
146146
delete[] netConfig->pipes[pipeIndex].name;
147147
}
148+
if (netConfig->nPreSyncs > 0)
149+
delete[] netConfig->preSyncs;
148150
delete[] netConfig->pipes;
149151
}
150152

src/nn/nn-executor.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <cassert>
22
#include <cstring>
3-
#include <stdexcept>
43
#include "nn-executor.hpp"
54

65
void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
@@ -39,6 +38,10 @@ NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmen
3938
this->segmentTo = segmentTo;
4039
}
4140

41+
NnExecutorException::NnExecutorException(const std::string message)
42+
: std::runtime_error(message)
43+
{}
44+
4245
NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
4346
: segments(nodeConfig->nSegments), steps()
4447
{
@@ -137,13 +140,19 @@ static inline void *executorThreadHandler(void *arg) {
137140
NnUint nThreads = context->nThreads;
138141
NnUint doneCount = nThreads - 1;
139142

140-
while (true) {
143+
while (context->isAlive.load()) {
141144
const unsigned int currentStepIndex = context->currentStepIndex.load();
142145
if (currentStepIndex == context->nSteps)
143146
break;
144147

145148
NnExecutorStep *step = &context->steps[currentStepIndex];
146-
executeStep(step, nThreads, thread, context);
149+
try {
150+
executeStep(step, nThreads, thread, context);
151+
} catch (const std::runtime_error &e) {
152+
context->isAlive.store(false);
153+
printf("🚨 Execution error: %s\n", e.what());
154+
break;
155+
}
147156

148157
NnUint currentCount = context->doneThreadCount.fetch_add(1);
149158
if (currentCount == doneCount) {
@@ -156,7 +165,10 @@ static inline void *executorThreadHandler(void *arg) {
156165
context->doneThreadCount.store(0);
157166
context->currentStepIndex.fetch_add(1);
158167
} else {
159-
while (context->currentStepIndex.load() == currentStepIndex);
168+
while (
169+
context->currentStepIndex.load() == currentStepIndex &&
170+
context->isAlive.load()
171+
);
160172
}
161173
}
162174
return nullptr;
@@ -166,6 +178,7 @@ void NnExecutor::forward() {
166178
assert(netExecution->batchSize > 0);
167179

168180
NnUint nThreads = netExecution->nThreads;
181+
context.isAlive.exchange(true);
169182
context.currentStepIndex.exchange(0);
170183
context.doneThreadCount.exchange(0);
171184
context.batchSize = netExecution->batchSize;
@@ -178,12 +191,14 @@ void NnExecutor::forward() {
178191
NnUint threadIndex;
179192
for (threadIndex = 1; threadIndex < nThreads; threadIndex++) {
180193
int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]);
181-
if (result != 0)
182-
throw std::runtime_error("Failed to create thread");
194+
assert(result == 0 && "Failed to create thread");
183195
}
184196
executorThreadHandler((void *)&threads[0]);
185197
for (threadIndex = 1; threadIndex < nThreads; threadIndex++)
186198
pthread_join(threads[threadIndex].handler, NULL);
199+
200+
if (!context.isAlive.load())
201+
throw NnExecutorException("Execution failed in one of the threads");
187202
}
188203

189204
NnUint NnExecutor::getTotalTime(NnExecutorStepType type) {

src/nn/nn-executor.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "nn-core.hpp"
55
#include <atomic>
66
#include <vector>
7+
#include <stdexcept>
78
#include "pthread.h"
89

910
class NnDeviceSegment {
@@ -73,6 +74,7 @@ typedef struct {
7374
NnNodeSynchronizer *synchronizer;
7475
std::atomic_uint currentStepIndex;
7576
std::atomic_uint doneThreadCount;
77+
std::atomic_bool isAlive;
7678
NnUint batchSize;
7779
Timer *timer;
7880
NnUint totalTime[N_STEP_TYPES];
@@ -84,6 +86,11 @@ typedef struct {
8486
PthreadHandler handler;
8587
} NnExecutorThread;
8688

89+
class NnExecutorException : public std::runtime_error {
90+
public:
91+
NnExecutorException(const std::string message);
92+
};
93+
8794
class NnExecutor {
8895
private:
8996
NnNetExecution *netExecution;

0 commit comments

Comments
 (0)