Commit 420a2b8

Implement impersonation feature for component
- `ImpersonateTemplatedPrompt` is used for user input (or non-streamed generated assistant input).
- `ImpersonateTemplatedToken` is used for generated stream input; signal EOS true to finish the response.
- Default system prompt insertion moved into the load-model function so user prompts can be chained directly after it without waiting for results. The system prompt is only inserted if the load succeeds.
- Log TGS if `bLogGenerationStats` is true.
- Bind `OnResponseGenerated` instead of a per-call callback method so that both local and remote responses flow through the same API.
1 parent 6de7e3a commit 420a2b8
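A minimal sketch of how the two impersonation entry points might be driven from an external completion source. The `AMyChatActor` class, its `LlamaComponent` member, and the `OnRemote*` handlers are hypothetical stand-ins for whatever delivers the remote data; only `ImpersonateTemplatedPrompt`, `ImpersonateTemplatedToken`, `FLlamaChatPrompt`, and `EChatTemplateRole` come from this commit:

//Hypothetical glue code: mirror a remote chat exchange into the component's
//loop so downstream delegates fire as if the data were generated locally.
void AMyChatActor::OnRemoteUserMessage(const FString& UserText)
{
    FLlamaChatPrompt ChatPrompt;
    ChatPrompt.Prompt = UserText;
    ChatPrompt.Role = EChatTemplateRole::User;

    //Inserted into chat history without triggering local generation.
    LlamaComponent->ImpersonateTemplatedPrompt(ChatPrompt);
}

void AMyChatActor::OnRemoteTokenReceived(const FString& Token, bool bStreamFinished)
{
    //Appends to the last assistant message; passing true for end-of-stream
    //emits the full-response callbacks just like a locally finished reply.
    LlamaComponent->ImpersonateTemplatedToken(Token, EChatTemplateRole::Assistant, bStreamFinished);
}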

File tree

6 files changed (+190, -24 lines)

Llama.uplugin

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 {
 	"FileVersion": 3,
 	"Version": 1,
-	"VersionName": "0.9.3",
+	"VersionName": "0.9.4",
 	"FriendlyName": "Llama",
 	"Description": "Llama.cpp plugin for large language model (LLM) inference.",
 	"Category": "LLM",

Source/LlamaCore/Private/LlamaComponent.cpp

Lines changed: 24 additions & 10 deletions
@@ -19,6 +19,12 @@ ULlamaComponent::ULlamaComponent(const FObjectInitializer &ObjectInitializer)
         OnTokenGenerated.Broadcast(Token);
     };
 
+    LlamaNative->OnResponseGenerated = [this](const FString& Response)
+    {
+        OnResponseGenerated.Broadcast(Response);
+        OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
+    };
+
     LlamaNative->OnPartialGenerated = [this](const FString& Partial)
     {
         OnPartialGenerated.Broadcast(Partial);
@@ -87,26 +93,26 @@ void ULlamaComponent::InsertTemplatedPrompt(const FString& Text, EChatTemplateRo
 
 void ULlamaComponent::InsertTemplatedPromptStruct(const FLlamaChatPrompt& ChatPrompt)
 {
-    LlamaNative->InsertTemplatedPrompt(ChatPrompt, [this, ChatPrompt](const FString& Response)
-    {
+    LlamaNative->InsertTemplatedPrompt(ChatPrompt);/*, [this, ChatPrompt](const FString& Response));
+    {
         if (ChatPrompt.bGenerateReply)
         {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    });*/
 }
 
 void ULlamaComponent::InsertRawPrompt(const FString& Text, bool bGenerateReply)
 {
-    LlamaNative->InsertRawPrompt(Text, bGenerateReply, [this, bGenerateReply](const FString& Response)
+    LlamaNative->InsertRawPrompt(Text, bGenerateReply); /*, [this, bGenerateReply](const FString& Response)
     {
         if (bGenerateReply)
         {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    });*/
 }
 
 void ULlamaComponent::LoadModel(bool bForceReload)
@@ -120,11 +126,6 @@ void ULlamaComponent::LoadModel(bool bForceReload)
             return;
         }
 
-        if (ModelParams.bAutoInsertSystemPromptOnLoad)
-        {
-            InsertTemplatedPrompt(ModelParams.SystemPrompt, EChatTemplateRole::System, false, false);
-        }
-
         OnModelLoaded.Broadcast(ModelPath);
     });
 }
@@ -163,6 +164,19 @@ void ULlamaComponent::RemoveLastUserInput()
     LlamaNative->RemoveLastUserInput();
 }
 
+
+void ULlamaComponent::ImpersonateTemplatedPrompt(const FLlamaChatPrompt& ChatPrompt)
+{
+    LlamaNative->SetModelParams(ModelParams);
+
+    LlamaNative->ImpersonateTemplatedPrompt(ChatPrompt);
+}
+
+void ULlamaComponent::ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role, bool bEoS)
+{
+    LlamaNative->ImpersonateTemplatedToken(Token, Role, bEoS);
+}
+
 FString ULlamaComponent::WrapPromptForRole(const FString& Text, EChatTemplateRole Role, const FString& Template)
 {
     return LlamaNative->WrapPromptForRole(Text, Role, Template);
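With the per-call lambdas commented out above, responses now flow exclusively through the bound `OnResponseGenerated`/`OnEndOfStream` delegates, so local and impersonated (remote) generation surface through the same path. A minimal listener sketch, assuming these are the usual BlueprintAssignable dynamic multicast delegates (the actor and handler names are illustrative):

void AMyChatActor::BeginPlay()
{
    Super::BeginPlay();

    //One-time binding; fires for local generation and impersonated responses alike.
    LlamaComponent->OnResponseGenerated.AddDynamic(this, &AMyChatActor::HandleResponse);
}

void AMyChatActor::HandleResponse(const FString& Response)
{
    UE_LOG(LogTemp, Log, TEXT("Response: %s"), *Response);
}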

Source/LlamaCore/Private/LlamaNative.cpp

Lines changed: 145 additions & 4 deletions
@@ -60,7 +60,7 @@ FLlamaNative::FLlamaNative()
     {
         if (ModelParams.Advanced.bLogGenerationStats)
         {
-            UE_LOG(LlamaLog, Log, TEXT("Generated %d tokens in %1.2fs (%1.2ftps)"), TokensGenerated, Duration, SpeedTps);
+            UE_LOG(LlamaLog, Log, TEXT("TGS - Generated %d tokens in %1.2fs (%1.2ftps)"), TokensGenerated, Duration, SpeedTps);
         }
 
         int32 UsedContext = UsedContextLength();
@@ -88,6 +88,11 @@ FLlamaNative::FLlamaNative()
 
     Internal->OnPromptProcessed = [this](int32 TokensProcessed, EChatTemplateRole RoleProcessed, float SpeedTps)
     {
+        if (ModelParams.Advanced.bLogGenerationStats)
+        {
+            UE_LOG(LlamaLog, Log, TEXT("PPS - Processed %d tokens at %1.2ftps"), TokensProcessed, SpeedTps);
+        }
+
         int32 UsedContext = UsedContextLength();
 
         //Sync history data with additional state updates
@@ -266,6 +271,7 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
         //already loaded, we're done
         return ModelLoadedCallback(ModelParams.PathToModel, 0);
     }
+    bModelLoadInitiated = true;
 
     //Copy so these dont get modified during enqueue op
     const FLLMModelParams ParamsAtLoad = ModelParams;
@@ -284,6 +290,14 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
             const FString TemplateString = FLlamaString::ToUE(Internal->Template);
             const FString TemplateSource = FLlamaString::ToUE(Internal->TemplateSource);
 
+            //Before we release the BG thread, ensure we enqueue the system prompt
+            //If we do it later, other queued calls will frontrun it. This enables startup chaining correctly
+            if (ParamsAtLoad.bAutoInsertSystemPromptOnLoad)
+            {
+                Internal->InsertTemplatedPrompt(FLlamaString::ToStd(ParamsAtLoad.SystemPrompt), EChatTemplateRole::System, false, false);
+            }
+
+            //Callback on game thread for data sync
             EnqueueGTTask([this, TemplateString, TemplateSource, ModelLoadedCallback]
             {
                 FJinjaChatTemplate ChatTemplate;
@@ -293,6 +307,8 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
                 ModelState.ChatTemplateInUse = ChatTemplate;
                 ModelState.bModelIsLoaded = true;
 
+                bModelLoadInitiated = false;
+
                 if (OnModelStateChanged)
                 {
                     OnModelStateChanged(ModelState);
@@ -308,15 +324,22 @@ void FLlamaNative::LoadModel(bool bForceReload, TFunction<void(const FString&, i
         {
             EnqueueGTTask([this, ModelLoadedCallback]
             {
+                bModelLoadInitiated = false;
+
                 //On error will be triggered earlier in the chain, but forward our model loading error status here
-                ModelLoadedCallback(ModelParams.PathToModel, 15);
+                if (ModelLoadedCallback)
+                {
+                    ModelLoadedCallback(ModelParams.PathToModel, 15);
+                }
             }, TaskId);
         }
     });
 }
 
 void FLlamaNative::UnloadModel(TFunction<void(int32 StatusCode)> ModelUnloadedCallback)
 {
+    bModelLoadInitiated = false;
+
     EnqueueBGTask([this, ModelUnloadedCallback](int64 TaskId)
     {
         if (IsModelLoaded())
@@ -349,7 +372,7 @@ bool FLlamaNative::IsModelLoaded()
 
 void FLlamaNative::InsertTemplatedPrompt(const FLlamaChatPrompt& Prompt, TFunction<void(const FString& Response)> OnResponseFinished)
 {
-    if (!IsModelLoaded())
+    if (!IsModelLoaded() && !bModelLoadInitiated)
     {
         UE_LOG(LlamaLog, Warning, TEXT("Model isn't loaded, can't run prompt."));
         return;
@@ -386,7 +409,7 @@ void FLlamaNative::InsertTemplatedPrompt(const FLlamaChatPrompt& Prompt, TFuncti
 
 void FLlamaNative::InsertRawPrompt(const FString& Prompt, bool bGenerateReply, TFunction<void(const FString& Response)>OnResponseFinished)
 {
-    if (!IsModelLoaded())
+    if (!IsModelLoaded() && !bModelLoadInitiated)
     {
         UE_LOG(LlamaLog, Warning, TEXT("Model isn't loaded, can't run prompt."));
         return;
@@ -407,6 +430,124 @@ void FLlamaNative::InsertRawPrompt(const FString& Prompt, bool bGenerateReply, T
     });
 }
 
+void FLlamaNative::ImpersonateTemplatedPrompt(const FLlamaChatPrompt& Prompt)
+{
+    //modify model state
+    if (IsModelLoaded())
+    {
+        //insert it but make sure we don't do any token generation
+        FLlamaChatPrompt ModifiedPrompt = Prompt;
+        ModifiedPrompt.bGenerateReply = false;
+
+        InsertTemplatedPrompt(ModifiedPrompt);
+    }
+    else
+    {
+        //no model, so just run this in sync mode
+        FStructuredChatMessage Message;
+        Message.Role = Prompt.Role;
+        Message.Content = Prompt.Prompt;
+
+        //modify our chat history state
+        ModelState.ChatHistory.History.Add(Message);
+
+        if (OnModelStateChanged)
+        {
+            OnModelStateChanged(ModelState);
+        }
+        //was this an assistant message? emit response generated callback
+        if (Message.Role == EChatTemplateRole::Assistant)
+        {
+            if (OnResponseGenerated)
+            {
+                OnResponseGenerated(Prompt.Prompt);
+            }
+        }
+    }
+}
+
+void FLlamaNative::ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role, bool bEoS)
+{
+    //Should be called on game thread.
+
+    //NB: we don't support updating model internal state atm
+
+    //Check if we need to add a message before modifying it
+    bool bLastRoleWasMatchingRole = false;
+
+    if (ModelState.ChatHistory.History.Num() != 0)
+    {
+        FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+        bLastRoleWasMatchingRole = Message.Role == Role;
+    }
+
+    FString CurrentReplyText;
+
+    if (!bLastRoleWasMatchingRole)
+    {
+        FStructuredChatMessage Message;
+        Message.Role = Role;
+        Message.Content = Token;
+
+        ModelState.ChatHistory.History.Add(Message);
+
+        CurrentReplyText += Token;
+    }
+    else
+    {
+        FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+        Message.Content += Token;
+
+        CurrentReplyText += Message.Content;
+    }
+
+    FStructuredChatMessage& Message = ModelState.ChatHistory.History.Last();
+
+    FString Partial;
+
+    //Compute Partials
+    if (ModelParams.Advanced.bEmitPartials)
+    {
+        bool bSplitFound = false;
+        //Check new token for separators
+        for (const FString& Separator : ModelParams.Advanced.PartialsSeparators)
+        {
+            if (Token.Contains(Separator))
+            {
+                bSplitFound = true;
+            }
+        }
+        if (bSplitFound)
+        {
+            Partial = FLlamaString::GetLastSentence(CurrentReplyText);
+        }
+    }
+
+    //Emit token to game thread
+    if (OnTokenGenerated)
+    {
+        OnTokenGenerated(Token);
+
+        if (OnPartialGenerated && !Partial.IsEmpty())
+        {
+            OnPartialGenerated(Partial);
+        }
+    }
+
+    //full response reply on finish
+    if (bEoS)
+    {
+        if (OnModelStateChanged)
+        {
+            OnModelStateChanged(ModelState);
+        }
+        if (OnResponseGenerated)
+        {
+            OnResponseGenerated(CurrentReplyText);
+        }
+    }
+}
+
 void FLlamaNative::RemoveLastNMessages(int32 MessageCount)
 {
     EnqueueBGTask([this, MessageCount](int64 TaskId)
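Because `LoadModel` now enqueues the system prompt before releasing the background thread, and the prompt-insert guards accept calls while `bModelLoadInitiated` is set, startup can be chained without waiting for the load callback. A sketch of the pattern the commit message describes (actor name and prompt text are illustrative):

void AMyChatActor::StartChat()
{
    //The system prompt (if bAutoInsertSystemPromptOnLoad) is enqueued inside LoadModel,
    //so this user prompt queues behind it instead of front-running or being rejected.
    LlamaComponent->LoadModel(false);

    FLlamaChatPrompt ChatPrompt;
    ChatPrompt.Prompt = TEXT("Hello!");
    ChatPrompt.Role = EChatTemplateRole::User;
    ChatPrompt.bGenerateReply = true;
    LlamaComponent->InsertTemplatedPromptStruct(ChatPrompt);
}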

Source/LlamaCore/Private/LlamaSubsystem.cpp

Lines changed: 9 additions & 9 deletions
@@ -29,6 +29,11 @@ void ULlamaSubsystem::Initialize(FSubsystemCollectionBase& Collection)
     {
         OnPromptProcessed.Broadcast(TokensProcessed, Role, Speed);
     };
+    LlamaNative->OnResponseGenerated = [this](const FString& Response)
+    {
+        OnResponseGenerated.Broadcast(Response);
+        OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
+    };
     LlamaNative->OnError = [this](const FString& ErrorMessage, int32 ErrorCode)
     {
         OnError.Broadcast(ErrorMessage, ErrorCode);
@@ -63,26 +68,26 @@ void ULlamaSubsystem::InsertTemplatedPrompt(const FString& Text, EChatTemplateRo
 
 void ULlamaSubsystem::InsertTemplatedPromptStruct(const FLlamaChatPrompt& ChatPrompt)
 {
-    LlamaNative->InsertTemplatedPrompt(ChatPrompt, [this, ChatPrompt](const FString& Response)
+    LlamaNative->InsertTemplatedPrompt(ChatPrompt);/*, [this, ChatPrompt](const FString& Response)
     {
         if (ChatPrompt.bGenerateReply)
         {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    });*/
 }
 
 void ULlamaSubsystem::InsertRawPrompt(const FString& Text, bool bGenerateReply)
 {
-    LlamaNative->InsertRawPrompt(Text, bGenerateReply, [this, bGenerateReply](const FString& Response)
+    LlamaNative->InsertRawPrompt(Text, bGenerateReply);/*, [this, bGenerateReply](const FString& Response)
     {
         if (bGenerateReply)
        {
             OnResponseGenerated.Broadcast(Response);
             OnEndOfStream.Broadcast(true, ModelState.LastTokenGenerationSpeed);
         }
-    });
+    })*/;
 }
 
 void ULlamaSubsystem::LoadModel(bool bForceReload)
@@ -104,11 +109,6 @@ void ULlamaSubsystem::LoadModel(bool bForceReload)
             return;
         }
 
-        if (ModelParams.bAutoInsertSystemPromptOnLoad)
-        {
-            InsertTemplatedPrompt(ModelParams.SystemPrompt, EChatTemplateRole::System, false, false);
-        }
-
         OnModelLoaded.Broadcast(ModelPath);
     });
 }

Source/LlamaCore/Public/LlamaComponent.h

Lines changed: 8 additions & 0 deletions
@@ -107,6 +107,14 @@ class LLAMACORE_API ULlamaComponent : public UActorComponent
     UFUNCTION(BlueprintCallable, Category = "LLM Model Component")
     void InsertRawPrompt(UPARAM(meta = (MultiLine = true)) const FString& Text, bool bGenerateReply = true);
 
+    //Typically as user, this pretends the input was generated in history and all downstream functions should trigger. KV-cache won't be updated if no models are loaded.
+    UFUNCTION(BlueprintCallable, Category = "LLM Model Component - Impersonation via External API")
+    void ImpersonateTemplatedPrompt(const FLlamaChatPrompt& ChatPrompt);
+
+    //Use this to feed external model inference through our loop (e.g. as assistant tokens are generated), it will pretend the output was generated locally downstream.
+    UFUNCTION(BlueprintCallable, Category = "LLM Model Component - Impersonation via External API")
+    void ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role = EChatTemplateRole::Assistant, bool bIsEndOfStream = false);
+
     //if you want to manually wrap prompt, if template is empty string, default model template is applied. NB: this function should be thread safe, but this has not be thoroughly tested.
     UFUNCTION(BlueprintPure, Category = "LLM Model Component")
     FString WrapPromptForRole(const FString& Text, EChatTemplateRole Role, const FString& OverrideTemplate);
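Per the header comments above, `ImpersonateTemplatedPrompt` also covers non-streamed assistant output: impersonating a whole assistant message updates chat history and emits `OnResponseGenerated` downstream even when no model is loaded (only the KV-cache update is skipped). A hedged example, where `RemoteFullReplyText` is an assumed variable holding an external API's complete reply:

//Feed a complete reply from an external API into the component as if generated locally.
FLlamaChatPrompt AssistantReply;
AssistantReply.Prompt = RemoteFullReplyText;
AssistantReply.Role = EChatTemplateRole::Assistant;

//With a model loaded this inserts with bGenerateReply forced to false;
//without one it only updates history and fires the response delegate.
LlamaComponent->ImpersonateTemplatedPrompt(AssistantReply);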

Source/LlamaCore/Public/LlamaNative.h

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,8 @@ class LLAMACORE_API FLlamaNative
         TFunction<void(const FString& Response)>OnResponseFinished = nullptr);
     void InsertRawPrompt(const FString& Prompt, bool bGenerateReply = true,
         TFunction<void(const FString& Response)>OnResponseFinished = nullptr);
+    void ImpersonateTemplatedPrompt(const FLlamaChatPrompt& Prompt);
+    void ImpersonateTemplatedToken(const FString& Token, EChatTemplateRole Role = EChatTemplateRole::Assistant, bool bEoS = false);
     bool IsGenerating();
     void StopGeneration();
     void ResumeGeneration();
@@ -82,6 +84,7 @@ class LLAMACORE_API FLlamaNative
     //GT State - safely accesible on game thread
     FLLMModelParams ModelParams;
     FLLMModelState ModelState;
+    bool bModelLoadInitiated = false; //tracking model load attempts
 
     //BG State - do not read/write on GT
     FString CombinedPieceText; //accumulates tokens into full string during per-token inference.
