Skip to content

Commit 9fd6af1

Browse files
authored
feat: implement OutputGuardrails for LLM response validation (#73)
1 parent 4000635 commit 9fd6af1

File tree

3 files changed

+112
-16
lines changed

3 files changed

+112
-16
lines changed

agent.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type Agent struct {
6060
CodeChain *chain.UtcpChainClient
6161

6262
AllowUnsafeTools bool
63+
Guardrails *OutputGuardrails
6364
}
6465

6566
// Options configure a new Agent.
@@ -77,6 +78,7 @@ type Options struct {
7778
Shared *memory.SharedSession
7879
CodeChain *chain.UtcpChainClient
7980
AllowUnsafeTools bool
81+
Guardrails *OutputGuardrails
8082
}
8183

8284
// New creates an Agent with the provided options.
@@ -146,6 +148,7 @@ func New(opts Options) (*Agent, error) {
146148
CodeMode: opts.CodeMode,
147149
CodeChain: opts.CodeChain,
148150
AllowUnsafeTools: opts.AllowUnsafeTools,
151+
Guardrails: opts.Guardrails,
149152
}
150153

151154
return a, nil
@@ -1214,7 +1217,17 @@ func (a *Agent) Generate(ctx context.Context, sessionID, userInput string) (any,
12141217
// 3. Chain Orchestrator (LLM decides a multi-step chain execution)
12151218
// -------------------------------------------------------------
12161219
if handled, output, err := a.codeChainOrchestrator(ctx, sessionID, userInput); handled {
1217-
return output, err
1220+
if err != nil {
1221+
return "", err
1222+
}
1223+
if a.Guardrails != nil {
1224+
validated, gErr := a.Guardrails.ValidateAndRepair(ctx, output)
1225+
if gErr != nil {
1226+
return "", gErr
1227+
}
1228+
return validated, nil
1229+
}
1230+
return output, nil
12181231
}
12191232
// ---------------------------------------------
12201233
// 4. TOOL ORCHESTRATOR (normal UTCP tools)
@@ -1270,7 +1283,17 @@ func (a *Agent) Generate(ctx context.Context, sessionID, userInput string) (any,
12701283
return "", err
12711284
}
12721285

1273-
a.storeMemory(sessionID, "assistant", fmt.Sprintf("%s", completion), nil)
1286+
finalText := fmt.Sprint(completion)
1287+
if a.Guardrails != nil {
1288+
validatedText, gErr := a.Guardrails.ValidateAndRepair(ctx, finalText)
1289+
if gErr != nil {
1290+
return "", gErr
1291+
}
1292+
finalText = validatedText
1293+
completion = finalText
1294+
}
1295+
1296+
a.storeMemory(sessionID, "assistant", finalText, nil)
12741297
return completion, nil
12751298
}
12761299

agent_stream.go

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ func (a *Agent) GenerateStream(ctx context.Context, sessionID, userInput string)
7373

7474
// 3. Chain Orchestrator
7575
if handled, output, err := a.codeChainOrchestrator(ctx, sessionID, userInput); handled {
76+
if err == nil && a.Guardrails != nil {
77+
validated, gErr := a.Guardrails.ValidateAndRepair(ctx, output)
78+
if gErr != nil {
79+
return immediateStream("", gErr)
80+
}
81+
output = validated
82+
}
7683
return immediateStream(output, err)
7784
}
7885

@@ -110,23 +117,51 @@ func (a *Agent) GenerateStream(ctx context.Context, sessionID, userInput string)
110117

111118
// Wrap the stream to intercept and store memory
112119
outCh := make(chan models.StreamChunk)
113-
go func() {
114-
defer close(outCh)
115-
var full strings.Builder
116-
for chunk := range stream {
117-
if chunk.Err != nil {
118-
outCh <- chunk
120+
121+
if a.Guardrails != nil {
122+
go func() {
123+
defer close(outCh)
124+
var full strings.Builder
125+
for chunk := range stream {
126+
if chunk.Err != nil {
127+
outCh <- chunk
128+
return
129+
}
130+
if chunk.Delta != "" {
131+
full.WriteString(chunk.Delta)
132+
}
133+
}
134+
135+
finalText := full.String()
136+
validatedText, gErr := a.Guardrails.ValidateAndRepair(ctx, finalText)
137+
if gErr != nil {
138+
outCh <- models.StreamChunk{Err: gErr, Done: true}
119139
return
120140
}
121-
if chunk.Delta != "" {
122-
full.WriteString(chunk.Delta)
141+
142+
// Stream out the validated text as one chunk
143+
outCh <- models.StreamChunk{Delta: validatedText, FullText: validatedText, Done: true}
144+
a.storeMemory(sessionID, "assistant", validatedText, nil)
145+
}()
146+
} else {
147+
go func() {
148+
defer close(outCh)
149+
var full strings.Builder
150+
for chunk := range stream {
151+
if chunk.Err != nil {
152+
outCh <- chunk
153+
return
154+
}
155+
if chunk.Delta != "" {
156+
full.WriteString(chunk.Delta)
157+
}
158+
outCh <- chunk
123159
}
124-
outCh <- chunk
125-
}
126-
// Store memory after completion
127-
finalText := full.String()
128-
a.storeMemory(sessionID, "assistant", finalText, nil)
129-
}()
160+
// Store memory after completion
161+
finalText := full.String()
162+
a.storeMemory(sessionID, "assistant", finalText, nil)
163+
}()
164+
}
130165

131166
return outCh, nil
132167
}

types.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,41 @@ type AgentState struct {
6262
JoinedSpaces []string `json:"joined_spaces,omitempty"`
6363
Timestamp time.Time `json:"timestamp"`
6464
}
65+
66+
// SafetyPolicy defines an interface for validating LLM responses.
67+
type SafetyPolicy interface {
68+
Validate(ctx context.Context, response string) error
69+
}
70+
71+
// FormatEnforcer defines an interface for validating or repairing the format of LLM responses.
72+
type FormatEnforcer interface {
73+
Enforce(ctx context.Context, response string) (string, error)
74+
}
75+
76+
// OutputGuardrails holds the policy engines and formatting rules.
77+
type OutputGuardrails struct {
78+
SafetyPolicies []SafetyPolicy
79+
FormatEnforcers []FormatEnforcer
80+
}
81+
82+
// ValidateAndRepair applies safety checks and format enforcing to the response.
83+
func (g *OutputGuardrails) ValidateAndRepair(ctx context.Context, response string) (string, error) {
84+
if g == nil {
85+
return response, nil
86+
}
87+
for _, policy := range g.SafetyPolicies {
88+
if err := policy.Validate(ctx, response); err != nil {
89+
return "", err
90+
}
91+
}
92+
93+
final := response
94+
var err error
95+
for _, enforcer := range g.FormatEnforcers {
96+
final, err = enforcer.Enforce(ctx, final)
97+
if err != nil {
98+
return "", err
99+
}
100+
}
101+
return final, nil
102+
}

0 commit comments

Comments
 (0)