
Commit b06036c

Merge branch 'feat-b4958'
2 parents: 16c17c2 + 420a2b8

13 files changed: +310 additions, −55 deletions

Llama.uplugin

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 {
 	"FileVersion": 3,
 	"Version": 1,
-	"VersionName": "0.9.2",
+	"VersionName": "0.9.4",
 	"FriendlyName": "Llama",
 	"Description": "Llama.cpp plugin for large language model (LLM) inference.",
 	"Category": "LLM",

Source/LlamaCore/Private/Internal/LlamaInternal.cpp

Lines changed: 8 additions & 8 deletions
@@ -273,7 +273,7 @@ int32 FLlamaInternal::UsedContext()
 {
     if (Context)
     {
-        return llama_get_kv_cache_used_cells(Context);
+        return llama_kv_self_used_cells(Context);
     }
     else
     {
@@ -313,16 +313,16 @@ void FLlamaInternal::ResetContextHistory(bool bKeepSystemsPrompt)
     ContextHistory.clear();
     Messages.clear();

-    llama_kv_cache_clear(Context);
+    llama_kv_self_clear(Context);
     FilledContextCharLength = 0;
 }

 void FLlamaInternal::RollbackContextHistoryByTokens(int32 NTokensToErase)
 {
     // clear the last n_regen tokens from the KV cache and update n_past
-    int32 TokensUsed = llama_get_kv_cache_used_cells(Context); //FilledContextCharLength
+    int32 TokensUsed = llama_kv_self_used_cells(Context); //FilledContextCharLength

-    llama_kv_cache_seq_rm(Context, 0, TokensUsed - NTokensToErase, -1);
+    llama_kv_self_seq_rm(Context, 0, TokensUsed - NTokensToErase, -1);

     //FilledContextCharLength -= NTokensToErase;

@@ -442,7 +442,7 @@ int32 FLlamaInternal::ProcessPrompt(const std::string& Prompt, EChatTemplateRole

     //Grab vocab
     const llama_vocab* Vocab = llama_model_get_vocab(LlamaModel);
-    const bool IsFirst = llama_get_kv_cache_used_cells(Context) == 0;
+    const bool IsFirst = llama_kv_self_used_cells(Context) == 0;

     // tokenize the prompt
     const int NPromptTokens = -llama_tokenize(Vocab, Prompt.c_str(), Prompt.size(), NULL, 0, IsFirst, true);
@@ -461,7 +461,7 @@ int32 FLlamaInternal::ProcessPrompt(const std::string& Prompt, EChatTemplateRole

     //check sizing before running prompt decode
     int NContext = llama_n_ctx(Context);
-    int NContextUsed = llama_get_kv_cache_used_cells(Context);
+    int NContextUsed = llama_kv_self_used_cells(Context);

     if (NContextUsed + NPromptTokens > NContext)
     {
@@ -506,7 +506,7 @@ int32 FLlamaInternal::ProcessPrompt(const std::string& Prompt, EChatTemplateRole

     // Check context before running decode
     int NContext = llama_n_ctx(Context);
-    int NContextUsed = llama_get_kv_cache_used_cells(Context);
+    int NContextUsed = llama_kv_self_used_cells(Context);

     if (NContextUsed + BatchTokens.size() > NContext)
     {
@@ -563,7 +563,7 @@ std::string FLlamaInternal::Generate(const std::string& Prompt, bool bAppendToMe

     // check if we have enough space in the context to evaluate this batch - might need to be inside loop
     int NContext = llama_n_ctx(Context);
-    int NContextUsed = llama_get_kv_cache_used_cells(Context);
+    int NContextUsed = llama_kv_self_used_cells(Context);
     bool bEOGExit = false;

     while (bGenerationActive) //processing can be aborted by flipping the boolean
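
Note: this file's changes are a mechanical migration from the deprecated llama_get_kv_cache_used_cells / llama_kv_cache_* entry points to the newer llama_kv_self_* names in llama.cpp (llama_kv_self_clear replaces llama_kv_cache_clear, llama_kv_self_seq_rm replaces llama_kv_cache_seq_rm). A minimal sketch of the renamed calls outside the plugin follows; the helper names HasRoomForPrompt and RollbackLastTokens are illustrative, not plugin API, and Context is assumed to be an already-created llama_context*.

#include "llama.h"

// Sketch: the same context-space bookkeeping the diff performs, expressed with
// the renamed llama_kv_self_* calls. NPromptTokens would come from llama_tokenize.
static bool HasRoomForPrompt(llama_context* Context, int NPromptTokens)
{
    const int NContext = (int)llama_n_ctx(Context);              // total context window
    const int NContextUsed = llama_kv_self_used_cells(Context);  // cells already occupied
    return NContextUsed + NPromptTokens <= NContext;
}

// Sketch: drop the last NTokensToErase tokens of sequence 0, mirroring what
// RollbackContextHistoryByTokens does in the diff above.
static void RollbackLastTokens(llama_context* Context, int NTokensToErase)
{
    const int TokensUsed = llama_kv_self_used_cells(Context);
    llama_kv_self_seq_rm(Context, 0, TokensUsed - NTokensToErase, -1);
}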

Source/LlamaCore/Private/LlamaComponent.cpp

Lines changed: 24 additions & 10 deletions
@@ -19,6 +19,12 @@ ULlamaComponent::ULlamaComponent(const FObjectInitializer &ObjectInitializer)
         OnTokenGenerated.Broadcast(Token);
     };

+    LlamaNative->OnResponseGenerated = [this](const FString& Response)
+    {
+        OnResponseGenerated.Broadcast(Response);
+        OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
+    };
+
     LlamaNative->OnPartialGenerated = [this](const FString& Partial)
     {
         OnPartialGenerated.Broadcast(Partial);
@@ -87,26 +93,26 @@ void ULlamaComponent::InsertTemplatedPrompt(const FString& Text, EChatTemplateRo

 void ULlamaComponent::InsertTemplatedPromptStruct(const FLlamaChatPrompt& ChatPrompt)
 {
-    LlamaNative->InsertTemplatedPrompt(ChatPrompt, [this, ChatPrompt](const FString& Response)
-    {
+    LlamaNative->InsertTemplatedPrompt(ChatPrompt);/*, [this, ChatPrompt](const FString& Response));
+    {
         if (ChatPrompt.bGenerateReply)
         {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    });*/
 }

 void ULlamaComponent::InsertRawPrompt(const FString& Text, bool bGenerateReply)
 {
-    LlamaNative->InsertRawPrompt(Text, bGenerateReply, [this, bGenerateReply](const FString& Response)
+    LlamaNative->InsertRawPrompt(Text, bGenerateReply); /*, [this, bGenerateReply](const FString& Response)
     {
         if (bGenerateReply)
         {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    });*/
 }

 void ULlamaComponent::LoadModel(bool bForceReload)
@@ -120,11 +126,6 @@ void ULlamaComponent::LoadModel(bool bForceReload)
         return;
     }

-    if (ModelParams.bAutoInsertSystemPromptOnLoad)
-    {
-        InsertTemplatedPrompt(ModelParams.SystemPrompt, EChatTemplateRole::System, false, false);
-    }
-
         OnModelLoaded.Broadcast(ModelPath);
     });
 }
@@ -163,6 +164,19 @@ void ULlamaComponent::RemoveLastUserInput()
     LlamaNative->RemoveLastUserInput();
 }

+
+void ULlamaComponent::ImpersonateTemplatedPrompt(const FLlamaChatPrompt& ChatPrompt)
+{
+    LlamaNative->SetModelParams(ModelParams);
+
+    LlamaNative->ImpersonateTemplatedPrompt(ChatPrompt);
+}
+
+void ULlamaComponent::ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role, bool bEoS)
+{
+    LlamaNative->ImpersonateTemplatedToken(Token, Role, bEoS);
+}
+
 FString ULlamaComponent::WrapPromptForRole(const FString& Text, EChatTemplateRole Role, const FString& Template)
 {
     return LlamaNative->WrapPromptForRole(Text, Role, Template);
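
Note: with the per-call response lambdas commented out, replies from this component now arrive only through the component-wide OnResponseGenerated and OnEndOfStream events wired up in the constructor above. A hedged consumer-side sketch follows, assuming these are dynamic multicast delegates on ULlamaComponent (as the Broadcast calls suggest); AMyChatActor and its UFUNCTION handlers are illustrative, not part of the plugin.

// Illustrative only: bind once to the component-level events instead of
// passing per-call callbacks.
void AMyChatActor::BeginPlay()
{
    Super::BeginPlay();

    LlamaComponent->OnResponseGenerated.AddDynamic(this, &AMyChatActor::HandleResponse);
    LlamaComponent->OnModelLoaded.AddDynamic(this, &AMyChatActor::HandleModelLoaded);

    LlamaComponent->LoadModel(false);
}

void AMyChatActor::HandleModelLoaded(const FString& ModelPath)
{
    // FLlamaChatPrompt fields as seen in this commit: Prompt, Role, bGenerateReply.
    FLlamaChatPrompt ChatPrompt;
    ChatPrompt.Prompt = TEXT("Hello!");
    ChatPrompt.Role = EChatTemplateRole::User;
    ChatPrompt.bGenerateReply = true;

    LlamaComponent->InsertTemplatedPromptStruct(ChatPrompt);
}

void AMyChatActor::HandleResponse(const FString& Response)
{
    UE_LOG(LogTemp, Log, TEXT("LLM reply: %s"), *Response);
}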

Source/LlamaCore/Private/LlamaNative.cpp

Lines changed: 145 additions & 4 deletions
@@ -60,7 +60,7 @@ FLlamaNative::FLlamaNative()
     {
         if (ModelParams.Advanced.bLogGenerationStats)
         {
-            UE_LOG(LlamaLog, Log, TEXT("Generated %d tokens in %1.2fs (%1.2ftps)"), TokensGenerated, Duration, SpeedTps);
+            UE_LOG(LlamaLog, Log, TEXT("TGS - Generated %d tokens in %1.2fs (%1.2ftps)"), TokensGenerated, Duration, SpeedTps);
         }

         int32 UsedContext = UsedContextLength();
@@ -88,6 +88,11 @@ FLlamaNative::FLlamaNative()

     Internal->OnPromptProcessed = [this](int32 TokensProcessed, EChatTemplateRole RoleProcessed, float SpeedTps)
     {
+        if (ModelParams.Advanced.bLogGenerationStats)
+        {
+            UE_LOG(LlamaLog, Log, TEXT("PPS - Processed %d tokens at %1.2ftps"), TokensProcessed, SpeedTps);
+        }
+
         int32 UsedContext = UsedContextLength();

         //Sync history data with additional state updates
@@ -266,6 +271,7 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
         //already loaded, we're done
         return ModelLoadedCallback(ModelParams.PathToModel, 0);
     }
+    bModelLoadInitiated = true;

     //Copy so these dont get modified during enqueue op
     const FLLMModelParams ParamsAtLoad = ModelParams;
@@ -284,6 +290,14 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
         const FString TemplateString = FLlamaString::ToUE(Internal->Template);
         const FString TemplateSource = FLlamaString::ToUE(Internal->TemplateSource);

+        //Before we release the BG thread, ensure we enqueue the system prompt
+        //If we do it later, other queued calls will frontrun it. This enables startup chaining correctly
+        if (ParamsAtLoad.bAutoInsertSystemPromptOnLoad)
+        {
+            Internal->InsertTemplatedPrompt(FLlamaString::ToStd(ParamsAtLoad.SystemPrompt), EChatTemplateRole::System, false, false);
+        }
+
+        //Callback on game thread for data sync
         EnqueueGTTask([this, TemplateString, TemplateSource, ModelLoadedCallback]
         {
             FJinjaChatTemplate ChatTemplate;
@@ -293,6 +307,8 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
             ModelState.ChatTemplateInUse = ChatTemplate;
             ModelState.bModelIsLoaded = true;

+            bModelLoadInitiated = false;
+
             if (OnModelStateChanged)
             {
                 OnModelStateChanged(ModelState);
@@ -308,15 +324,22 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
         {
             EnqueueGTTask([this, ModelLoadedCallback]
             {
+                bModelLoadInitiated = false;
+
                 //On error will be triggered earlier in the chain, but forward our model loading error status here
-                ModelLoadedCallback(ModelParams.PathToModel, 15);
+                if (ModelLoadedCallback)
+                {
+                    ModelLoadedCallback(ModelParams.PathToModel, 15);
+                }
             }, TaskId);
         }
     });
 }

 void FLlamaNative::UnloadModel(TFunction<void(int32 StatusCode)> ModelUnloadedCallback)
 {
+    bModelLoadInitiated = false;
+
     EnqueueBGTask([this, ModelUnloadedCallback](int64 TaskId)
     {
         if (IsModelLoaded())
@@ -349,7 +372,7 @@ bool FLlamaNative::IsModelLoaded()

 void FLlamaNative::InsertTemplatedPrompt(const FLlamaChatPrompt& Prompt, TFunction<void(const FString& Response)> OnResponseFinished)
 {
-    if (!IsModelLoaded())
+    if (!IsModelLoaded() && !bModelLoadInitiated)
    {
         UE_LOG(LlamaLog, Warning, TEXT("Model isn't loaded, can't run prompt."));
         return;
@@ -386,7 +409,7 @@ void FLlamaNative::InsertTemplatedPrompt(const FLlamaChatPrompt& Prompt, TFuncti

 void FLlamaNative::InsertRawPrompt(const FString& Prompt, bool bGenerateReply, TFunction<void(const FString& Response)>OnResponseFinished)
 {
-    if (!IsModelLoaded())
+    if (!IsModelLoaded() && !bModelLoadInitiated)
     {
         UE_LOG(LlamaLog, Warning, TEXT("Model isn't loaded, can't run prompt."));
         return;
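
Note: the new bModelLoadInitiated flag lets InsertTemplatedPrompt / InsertRawPrompt accept prompts while a load is still in flight, and the system prompt is now enqueued on the background thread before the load task completes, so queued calls can no longer frontrun it. A hedged sketch of the call pattern this enables follows; names and signatures are taken from this commit, while the surrounding code is illustrative and Llama is assumed to be a valid FLlamaNative instance owned by the caller.

// Illustrative: chaining a load and the first prompt back-to-back. The prompt is
// accepted because bModelLoadInitiated is set, and it queues behind both the load
// task and the auto-inserted system prompt.
Llama->LoadModel(false, [](const FString& ModelPath, int32 StatusCode)
{
    // StatusCode 0 = success; 15 is the load-error status forwarded above.
});

FLlamaChatPrompt ChatPrompt;
ChatPrompt.Prompt = TEXT("Summarize the latest patch notes.");
ChatPrompt.Role = EChatTemplateRole::User;
ChatPrompt.bGenerateReply = true;

Llama->InsertTemplatedPrompt(ChatPrompt);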
@@ -407,6 +430,124 @@ void FLlamaNative::InsertRawPrompt(const FString& Prompt, bool bGenerateReply, T
     });
 }

+void FLlamaNative::ImpersonateTemplatedPrompt(const FLlamaChatPrompt& Prompt)
+{
+    //modify model state
+    if (IsModelLoaded())
+    {
+        //insert it but make sure we don't do any token generation
+        FLlamaChatPrompt ModifiedPrompt = Prompt;
+        ModifiedPrompt.bGenerateReply = false;
+
+        InsertTemplatedPrompt(ModifiedPrompt);
+    }
+    else
+    {
+        //no model, so just run this in sync mode
+        FStructuredChatMessage Message;
+        Message.Role = Prompt.Role;
+        Message.Content = Prompt.Prompt;
+
+        //modify our chat history state
+        ModelState.ChatHistory.History.Add(Message);
+
+        if (OnModelStateChanged)
+        {
+            OnModelStateChanged(ModelState);
+        }
+        //was this an assistant message? emit response generated callback
+        if (Message.Role == EChatTemplateRole::Assistant)
+        {
+            if (OnResponseGenerated)
+            {
+                OnResponseGenerated(Prompt.Prompt);
+            }
+        }
+    }
+}
+
+void FLlamaNative::ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role, bool bEoS)
+{
+    //Should be called on game thread.
+
+    //NB: we don't support updating model internal state atm
+
+    //Check if we need to add a message before modifying it
+    bool bLastRoleWasMatchingRole = false;
+
+    if (ModelState.ChatHistory.History.Num() != 0)
+    {
+        FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+        bLastRoleWasMatchingRole = Message.Role == Role;
+    }
+
+    FString CurrentReplyText;
+
+    if (!bLastRoleWasMatchingRole)
+    {
+        FStructuredChatMessage Message;
+        Message.Role = Role;
+        Message.Content = Token;
+
+        ModelState.ChatHistory.History.Add(Message);
+
+        CurrentReplyText += Token;
+    }
+    else
+    {
+        FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+        Message.Content += Token;
+
+        CurrentReplyText += Message.Content;
+    }
+
+    FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+
+    FString Partial;
+
+    //Compute Partials
+    if (ModelParams.Advanced.bEmitPartials)
+    {
+        bool bSplitFound = false;
+        //Check new token for separators
+        for (const FString& Separator : ModelParams.Advanced.PartialsSeparators)
+        {
+            if (Token.Contains(Separator))
+            {
+                bSplitFound = true;
+            }
+        }
+        if (bSplitFound)
+        {
+            Partial = FLlamaString::GetLastSentence(CurrentReplyText);
+        }
+    }
+
+    //Emit token to game thread
+    if (OnTokenGenerated)
+    {
+        OnTokenGenerated(Token);
+
+        if (OnPartialGenerated && !Partial.IsEmpty())
+        {
+            OnPartialGenerated(Partial);
+        }
+    }
+
+    //full response reply on finish
+    if (bEoS)
+    {
+        if (OnModelStateChanged)
+        {
+            OnModelStateChanged(ModelState);
+        }
+        if (OnResponseGenerated)
+        {
+            OnResponseGenerated(CurrentReplyText);
+        }
+    }
+}
+
 void FLlamaNative::RemoveLastNMessages(int32 MessageCount)
 {
     EnqueueBGTask([this, MessageCount](int64 TaskId)
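
Note: ImpersonateTemplatedPrompt and ImpersonateTemplatedToken mirror externally produced messages into the chat history and fire the same token/partial/response callbacks as local inference, without updating the model's internal state. A hedged usage sketch follows; the remote token source and the AMyChatActor handlers are illustrative, not part of the plugin.

// Illustrative: replaying a reply streamed from elsewhere (e.g. a server)
// through the component so listeners see the usual events.
void AMyChatActor::OnRemoteTokenReceived(const FString& Token, bool bIsLastToken)
{
    // Appends to (or starts) an assistant message and fires OnTokenGenerated /
    // OnPartialGenerated; when bEoS is true it also fires OnResponseGenerated.
    LlamaComponent->ImpersonateTemplatedToken(Token, EChatTemplateRole::Assistant, bIsLastToken);
}

void AMyChatActor::OnRemoteUserMessage(const FString& Text)
{
    // Insert a whole message without generating a local reply
    // (ImpersonateTemplatedPrompt forces bGenerateReply to false when a model is loaded).
    FLlamaChatPrompt ChatPrompt;
    ChatPrompt.Prompt = Text;
    ChatPrompt.Role = EChatTemplateRole::User;
    LlamaComponent->ImpersonateTemplatedPrompt(ChatPrompt);
}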

0 commit comments
