
Commit a0420bc

implement batch pacing
- simple sleep per generated token to pace TGS (token generation speed)
- sleep and split params on prompt processing (todo: instead of an even split, consider a max-tokens-per-batch cap)
- basic implementation of #38; more advanced pacing may be needed in the future
1 parent 32a1782 commit a0420bc
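
For context, a minimal sketch of how the pacing parameters introduced here might be configured before loading a model. Only the Advanced.* field names and LoadModelFromParams come from the diffs below; the surrounding setup and any omitted FLLMModelParams fields are assumptions, not part of this commit.

    // Hypothetical setup sketch, not part of this commit. Only the Advanced.* field
    // names are taken from the diffs below; the surrounding code and any omitted
    // FLLMModelParams fields (model path, context size, ...) are assumed.
    void ConfigurePacingExample(FLlamaInternal& Llama)
    {
        FLLMModelParams Params;
        Params.Advanced.TokenGenerationPacingSleep = 0.01f;    // sleep ~10 ms after each generated token
        Params.Advanced.PromptProcessingPacingSleep = 0.005f;  // > 0 enables the split-and-sleep prompt path
        Params.Advanced.PromptProcessingPacingSplitN = 4;      // prompt is processed as 4 roughly even batches

        // LoadModelFromParams caches the params as LastLoadedParams (see LlamaInternal.cpp diff).
        Llama.LoadModelFromParams(Params);
    }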

4 files changed: +83 additions, -18 deletions


Llama.uplugin

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 {
     "FileVersion": 3,
     "Version": 1,
-    "VersionName": "0.9.1",
+    "VersionName": "0.9.2",
     "FriendlyName": "Llama",
     "Description": "Llama.cpp plugin for large language model (LLM) inference.",
     "Category": "LLM",

Source/LlamaCore/Private/Internal/LlamaInternal.cpp

Lines changed: 77 additions & 16 deletions
@@ -14,6 +14,8 @@ bool FLlamaInternal::LoadModelFromParams(const FLLMModelParams& InModelParams)
 
     UE_LOG(LogTemp, Log, TEXT("Device Found: %s %s"), *GPU, *RHI);
 
+    LastLoadedParams = InModelParams;
+
     // only print errors
     llama_log_set([](enum ggml_log_level level, const char* text, void* /* user_data */) {
         if (level >= GGML_LOG_LEVEL_ERROR) {
@@ -451,27 +453,80 @@ int32 FLlamaInternal::ProcessPrompt(const std::string& Prompt, EChatTemplateRole
         return NPromptTokens;
     }
 
-    // prepare a batch for the prompt
-    llama_batch Batch = llama_batch_get_one(PromptTokens.data(), PromptTokens.size());
+    // All in one batch
+    if (LastLoadedParams.Advanced.PromptProcessingPacingSleep == 0.f)
+    {
+        // prepare a batch for the prompt
+        llama_batch Batch = llama_batch_get_one(PromptTokens.data(), PromptTokens.size());
 
-    // check sizing before running prompt decode
-    int NContext = llama_n_ctx(Context);
-    int NContextUsed = llama_get_kv_cache_used_cells(Context);
+        // check sizing before running prompt decode
+        int NContext = llama_n_ctx(Context);
+        int NContextUsed = llama_get_kv_cache_used_cells(Context);
 
-    if (NContextUsed + NPromptTokens > NContext)
-    {
-        EmitErrorMessage(FString::Printf(
-            TEXT("Failed to insert, tried to insert %d tokens to currently used %d tokens which is more than the max %d context size. Try increasing the context size and re-run prompt."),
-            NPromptTokens, NContextUsed, NContext
+        if (NContextUsed + NPromptTokens > NContext)
+        {
+            EmitErrorMessage(FString::Printf(
+                TEXT("Failed to insert, tried to insert %d tokens to currently used %d tokens which is more than the max %d context size. Try increasing the context size and re-run prompt."),
+                NPromptTokens, NContextUsed, NContext
         ), 22, __func__);
-        return 0;
-    }
+            return 0;
+        }
 
-    // run it through the decode (input)
-    if (llama_decode(Context, Batch))
+        // run it through the decode (input)
+        if (llama_decode(Context, Batch))
+        {
+            EmitErrorMessage(TEXT("Failed to decode, could not find a KV slot for the batch (try reducing the size of the batch or increase the context)."), 23, __func__);
+            return NPromptTokens;
+        }
+    }
+    // Split it and sleep between batches for pacing purposes
+    else
     {
-        EmitErrorMessage(TEXT("Failed to decode, could not find a KV slot for the batch (try reducing the size of the batch or increase the context)."), 23, __func__);
-        return NPromptTokens;
+        int32 BatchCount = LastLoadedParams.Advanced.PromptProcessingPacingSplitN;
+
+        int32 TotalTokens = PromptTokens.size();
+        int32 TokensPerBatch = TotalTokens / BatchCount;
+        int32 Remainder = TotalTokens % BatchCount;
+
+        int32 StartIndex = 0;
+
+        for (int32 i = 0; i < BatchCount; i++)
+        {
+            // Calculate how many tokens to put in this batch
+            int32 CurrentBatchSize = TokensPerBatch + (i < Remainder ? 1 : 0);
+
+            // Slice the relevant tokens for this batch
+            std::vector<llama_token> BatchTokens(
+                PromptTokens.begin() + StartIndex,
+                PromptTokens.begin() + StartIndex + CurrentBatchSize
+            );
+
+            // Prepare the batch
+            llama_batch Batch = llama_batch_get_one(BatchTokens.data(), BatchTokens.size());
+
+            // Check context before running decode
+            int NContext = llama_n_ctx(Context);
+            int NContextUsed = llama_get_kv_cache_used_cells(Context);
+
+            if (NContextUsed + BatchTokens.size() > NContext)
+            {
+                EmitErrorMessage(FString::Printf(
+                    TEXT("Failed to insert, tried to insert %d tokens to currently used %d tokens which is more than the max %d context size. Try increasing the context size and re-run prompt."),
+                    BatchTokens.size(), NContextUsed, NContext
+                ), 22, __func__);
+                return 0;
+            }
+
+            // Decode this batch
+            if (llama_decode(Context, Batch))
+            {
+                EmitErrorMessage(TEXT("Failed to decode, could not find a KV slot for the batch (try reducing the size of the batch or increase the context)."), 23, __func__);
+                return BatchTokens.size();
+            }
+
+            StartIndex += CurrentBatchSize;
+            FPlatformProcess::Sleep(LastLoadedParams.Advanced.PromptProcessingPacingSleep);
+        }
     }
 
     const auto StopTime = ggml_time_us();
@@ -561,6 +616,12 @@ std::string FLlamaInternal::Generate(const std::string& Prompt, bool bAppendToMe
             //Return partial response
            return Response;
         }
+
+        // sleep pacing
+        if (LastLoadedParams.Advanced.TokenGenerationPacingSleep > 0.f)
+        {
+            FPlatformProcess::Sleep(LastLoadedParams.Advanced.TokenGenerationPacingSleep);
+        }
     }
 
     bGenerationActive = false;

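To make the split arithmetic in ProcessPrompt concrete: with 10 prompt tokens and PromptProcessingPacingSplitN = 4, TokensPerBatch = 10 / 4 = 2 and Remainder = 10 % 4 = 2, so the first two batches carry 3 tokens and the last two carry 2 (3 + 3 + 2 + 2 = 10). A standalone sketch of just that distribution logic, with no Unreal or llama.cpp dependencies:

    #include <cstdio>
    #include <vector>

    // Mirrors the even-split logic in ProcessPrompt: BatchCount batches, with the
    // first (TotalTokens % BatchCount) batches carrying one extra token each.
    std::vector<int> SplitSizes(int TotalTokens, int BatchCount)
    {
        std::vector<int> Sizes;
        int TokensPerBatch = TotalTokens / BatchCount;
        int Remainder = TotalTokens % BatchCount;
        for (int i = 0; i < BatchCount; i++)
        {
            Sizes.push_back(TokensPerBatch + (i < Remainder ? 1 : 0));
        }
        return Sizes;
    }

    int main()
    {
        // 10 prompt tokens split across PromptProcessingPacingSplitN = 4 batches.
        for (int Size : SplitSizes(10, 4))
        {
            std::printf("%d ", Size); // prints: 3 3 2 2
        }
        std::printf("\n");
        return 0;
    }
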
Source/LlamaCore/Public/Internal/LlamaInternal.h

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class FLlamaInternal
     std::string Template;
     std::string TemplateSource;
 
-    //Pacing
+    //Cached params, should be accessed on BT
     FLLMModelParams LastLoadedParams;
 
     //Model loading

Source/LlamaCore/Public/LlamaDataTypes.h

Lines changed: 4 additions & 0 deletions
@@ -127,6 +127,10 @@ struct FLLMModelAdvancedParams
     UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "LLM Model Params")
     float PromptProcessingPacingSleep = 0.f;
 
+    // Only active if PromptProcessingPacingSleep > 0.f: splits the prompt into n chunks with a sleep between each.
+    UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "LLM Model Params")
+    int32 PromptProcessingPacingSplitN = 4;
+
     //usually . ? !
     UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "LLM Model Params")
     TArray<FString> PartialsSeparators;

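Since the commit message frames TokenGenerationPacingSleep as a per-token sleep to throttle TGS, one way to pick a value is as a rough tokens-per-second ceiling: sleeping 1/N seconds per token caps generation at about N tokens/s, minus whatever the decode itself costs. The helper below is purely illustrative and not part of the plugin; the plugin simply sleeps the fixed configured value after each generated token.

    #include <cstdio>

    // Hypothetical helper, not plugin API: picks a per-token sleep for a rough
    // tokens-per-second ceiling.
    static float PacingSleepForTargetTGS(float TargetTokensPerSecond)
    {
        if (TargetTokensPerSecond <= 0.f)
        {
            return 0.f; // 0 disables pacing, matching the > 0.f check in the diff
        }
        // 1/target seconds per token caps throughput at roughly the target,
        // ignoring decode time itself (so the real rate ends up a bit lower).
        return 1.f / TargetTokensPerSecond;
    }

    int main()
    {
        std::printf("%.3f\n", PacingSleepForTargetTGS(20.f)); // 0.050 -> at most ~20 tokens/s
        return 0;
    }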