-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathaddon_invocation_worker.go
More file actions
278 lines (238 loc) · 9.23 KB
/
addon_invocation_worker.go
File metadata and controls
278 lines (238 loc) · 9.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
package api
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/ericfitz/tmi/internal/slogging"
"github.com/google/uuid"
)
// AddonInvocationWorker handles delivery of add-on invocations to webhooks.
//
// Invocation IDs are enqueued with QueueInvocation and consumed by the
// background loop launched by Start; Stop closes stopChan to end that loop.
type AddonInvocationWorker struct {
	httpClient *http.Client // client for outbound webhook POSTs (timeout, no redirects; see NewAddonInvocationWorker)
	// running is set by Start and cleared by Stop. It is a plain bool with no
	// mutex/atomic, so Start/Stop should not be called concurrently.
	running  bool
	stopChan chan struct{}  // closed exactly once by Stop to signal shutdown
	workChan chan uuid.UUID // Channel for invocation IDs to process
	baseURL  string         // Server base URL for callback URLs
}
// AddonInvocationPayload represents the payload sent to webhook endpoints.
//
// It is marshalled to JSON and POSTed as the request body. When the webhook
// has a secret, the X-Webhook-Signature header is the HMAC of these exact
// marshalled bytes.
type AddonInvocationPayload struct {
	// EventType identifies the event; this worker always sends "addon.invoked".
	EventType     string          `json:"event_type"`
	InvocationID  uuid.UUID       `json:"invocation_id"`
	AddonID       uuid.UUID       `json:"addon_id"`
	ThreatModelID uuid.UUID       `json:"threat_model_id"`
	ObjectType    string          `json:"object_type,omitempty"`
	// ObjectID is omitted from the JSON when nil.
	ObjectID      *uuid.UUID      `json:"object_id,omitempty"`
	// Timestamp is the invocation's creation time, not the time of sending.
	Timestamp     time.Time       `json:"timestamp"`
	// Payload is embedded verbatim, so it must already be valid JSON or
	// marshalling the whole struct fails.
	Payload       json.RawMessage `json:"payload"`
	// CallbackURL is where the webhook may report status:
	// "<baseURL>/invocations/<invocation id>/status".
	CallbackURL   string          `json:"callback_url"`
}
// NewAddonInvocationWorker returns a worker that is ready to Start.
//
// The HTTP client has a 30-second overall timeout and never follows
// redirects: a 3xx reply is handed back to the caller as-is. The work queue
// buffers up to 100 pending invocation IDs. The callback base URL defaults to
// "http://localhost:8080" and should be overridden with SetBaseURL.
func NewAddonInvocationWorker() *AddonInvocationWorker {
	refuseRedirects := func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse // Don't follow redirects
	}

	client := &http.Client{
		CheckRedirect: refuseRedirects,
		Timeout:       30 * time.Second,
	}

	worker := &AddonInvocationWorker{
		baseURL:    "http://localhost:8080",
		httpClient: client,
		stopChan:   make(chan struct{}),
		workChan:   make(chan uuid.UUID, 100),
	}
	return worker
}
// SetBaseURL sets the server's public base URL. It is used to build the
// callback_url sent to webhooks: "<baseURL>/invocations/<id>/status".
//
// Trailing slashes are trimmed so the joined callback URL never contains "//"
// (e.g. "https://tmi.example.com/" becomes "https://tmi.example.com").
// The field is read by the processing goroutine without synchronization, so
// call this before Start.
func (w *AddonInvocationWorker) SetBaseURL(baseURL string) {
	// Equivalent to strings.TrimRight(baseURL, "/"), written inline to avoid
	// adding an import for a single call.
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}
	w.baseURL = baseURL
}
// Start launches the processing loop on a new goroutine and returns at once;
// the error result is currently always nil.
//
// The loop runs until ctx is cancelled or Stop is called. Each call launches
// another loop. Stop closes stopChan permanently, so a loop started after
// Stop exits immediately.
func (w *AddonInvocationWorker) Start(ctx context.Context) error {
	lg := slogging.Get()

	w.running = true
	lg.Info("addon invocation worker started")

	// The loop owns delivery from here on; Start does not wait for it.
	go w.processLoop(ctx)

	return nil
}
// Stop signals the processing loop to exit by closing stopChan.
//
// It does not wait for an invocation that is already being delivered. Repeat
// calls are no-ops thanks to the running flag. That flag is unsynchronized,
// so Stop should not race with Start or with another Stop.
func (w *AddonInvocationWorker) Stop() {
	lg := slogging.Get()
	if !w.running {
		return // never started, or already stopped
	}

	w.running = false
	close(w.stopChan)
	lg.Info("addon invocation worker stopped")
}
// QueueInvocation enqueues an invocation ID for the processing loop without
// blocking.
//
// If the work queue is full, the ID is dropped and a warning is logged.
// NOTE(review): a dropped invocation's stored status is not updated here.
func (w *AddonInvocationWorker) QueueInvocation(invocationID uuid.UUID) {
	select {
	case w.workChan <- invocationID:
		return // queued
	default:
		// queue full; fall through to the warning below
	}

	slogging.Get().Warn("addon invocation worker queue full, dropping invocation: %s", invocationID)
}
// processLoop consumes invocation IDs from workChan and delivers them one at a
// time via processInvocation, until ctx is cancelled or stopChan is closed.
// Delivery errors are logged and do not stop the loop.
//
// The loop is driven only by channels. It deliberately does not read
// w.running, because Stop writes that field from another goroutine without
// synchronization, which would be a data race.
func (w *AddonInvocationWorker) processLoop(ctx context.Context) {
	logger := slogging.Get()
	for {
		// select picks uniformly at random among ready cases, so check for
		// shutdown on its own first. Otherwise a backlog in workChan could keep
		// the worker delivering after Stop or cancellation.
		select {
		case <-ctx.Done():
			logger.Info("context cancelled, stopping invocation worker")
			return
		case <-w.stopChan:
			logger.Info("stop signal received, stopping invocation worker")
			return
		default:
		}

		select {
		case <-ctx.Done():
			logger.Info("context cancelled, stopping invocation worker")
			return
		case <-w.stopChan:
			logger.Info("stop signal received, stopping invocation worker")
			return
		case invocationID := <-w.workChan:
			if err := w.processInvocation(ctx, invocationID); err != nil {
				logger.Error("error processing invocation %s: %v", invocationID, err)
			}
		}
	}
}
// processInvocation delivers one queued invocation to its add-on's webhook and
// records the outcome on the stored invocation.
//
// Steps:
//  1. Load the invocation, its add-on, and the add-on's webhook subscription.
//  2. Build an AddonInvocationPayload whose callback_url is
//     "<baseURL>/invocations/<id>/status".
//  3. POST it once, with no retries. The body is signed with HMAC-SHA256 in
//     X-Webhook-Signature when the webhook has a secret.
//
// Resulting invocation status:
//   - 2xx with response header "X-TMI-Callback: async": in progress, awaiting
//     the webhook's callback.
//   - Any other 2xx: completed, with StatusPercent 100.
//   - Failed, with a StatusMessage explaining why, in all of these cases:
//     non-2xx (including 3xx, since the client does not follow redirects),
//     transport error, add-on or webhook lookup failure, inactive webhook,
//     and payload or request construction failure.
//
// Return value:
//   - nil after a successful delivery and for an inactive webhook.
//   - The underlying error in every other case.
//
// If the invocation itself cannot be loaded, nothing is recorded. Failures to
// persist a status update are logged, not returned.
func (w *AddonInvocationWorker) processInvocation(ctx context.Context, invocationID uuid.UUID) error {
	logger := slogging.Get()

	invocation, err := GlobalAddonInvocationStore.Get(ctx, invocationID)
	if err != nil {
		logger.Error("failed to get invocation %s: %v", invocationID, err)
		return err
	}

	// persistStatus saves the invocation's current status. It is best-effort:
	// an update error is logged rather than returned, so the caller still sees
	// the delivery outcome or the original failure.
	persistStatus := func() {
		if updateErr := GlobalAddonInvocationStore.Update(ctx, invocation); updateErr != nil {
			logger.Error("failed to update invocation status: %v", updateErr)
		}
	}
	// fail marks the invocation failed with msg and persists it.
	fail := func(msg string) {
		invocation.Status = InvocationStatusFailed
		invocation.StatusMessage = msg
		persistStatus()
	}

	addon, err := GlobalAddonStore.Get(ctx, invocation.AddonID)
	if err != nil {
		logger.Error("failed to get add-on %s: %v", invocation.AddonID, err)
		// Record the failure, as the webhook lookup below does. Otherwise the
		// invocation would keep whatever status it had indefinitely.
		fail(fmt.Sprintf("Add-on not found: %v", err))
		return err
	}

	webhook, err := GlobalWebhookSubscriptionStore.Get(addon.WebhookID.String())
	if err != nil {
		logger.Error("failed to get webhook %s: %v", addon.WebhookID, err)
		fail(fmt.Sprintf("Webhook not found: %v", err))
		return err
	}

	// An inactive webhook is an expected condition, not a processing error.
	if webhook.Status != "active" {
		logger.Warn("webhook %s is not active (status: %s), failing invocation", webhook.Id, webhook.Status)
		fail(fmt.Sprintf("Webhook not active (status: %s)", webhook.Status))
		return nil
	}

	logger.Debug("sending addon invocation to %s (invocation: %s)", webhook.Url, invocationID)

	callbackURL := fmt.Sprintf("%s/invocations/%s/status", w.baseURL, invocationID)
	payload := AddonInvocationPayload{
		EventType:     "addon.invoked",
		InvocationID:  invocation.ID,
		AddonID:       invocation.AddonID,
		ThreatModelID: invocation.ThreatModelID,
		ObjectType:    invocation.ObjectType,
		ObjectID:      invocation.ObjectID,
		Timestamp:     invocation.CreatedAt,
		Payload:       json.RawMessage(invocation.Payload),
		CallbackURL:   callbackURL,
	}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		logger.Error("failed to marshal invocation payload: %v", err)
		fail(fmt.Sprintf("Failed to marshal payload: %v", err))
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhook.Url, bytes.NewReader(payloadBytes))
	if err != nil {
		logger.Error("failed to create addon invocation request for %s: %v", invocationID, err)
		fail(fmt.Sprintf("Failed to create request: %v", err))
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Webhook-Event", "addon.invoked")
	req.Header.Set("X-Invocation-Id", invocationID.String())
	req.Header.Set("X-Addon-Id", invocation.AddonID.String())
	req.Header.Set("User-Agent", "TMI-Addon-Worker/1.0")
	// Sign exactly the bytes being sent, so the receiver can verify the raw body.
	if webhook.Secret != "" {
		req.Header.Set("X-Webhook-Signature", w.generateSignature(payloadBytes, webhook.Secret))
	}

	// No retries: the webhook can report failures through the callback URL.
	resp, err := w.httpClient.Do(req)
	if err != nil {
		logger.Error("addon invocation request failed for %s: %v", invocationID, err)
		fail(fmt.Sprintf("Request failed: %v", err))
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	// Read at most 10KB of the response. It is used only in the failure
	// message, so an oversized or hostile body cannot bloat it.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024))

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		errorMsg := fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(body))
		logger.Error("addon invocation failed for %s: %s", invocationID, errorMsg)
		fail(errorMsg)
		return fmt.Errorf("invocation failed: %s", errorMsg)
	}

	logger.Info("addon invocation sent successfully to %s (invocation: %s, status: %d)",
		webhook.Url, invocationID, resp.StatusCode)

	// "X-TMI-Callback: async" means the webhook will call back with status
	// updates. Otherwise the webhook handled the work internally, so the
	// invocation is auto-completed.
	if resp.Header.Get("X-TMI-Callback") == "async" {
		invocation.Status = InvocationStatusInProgress
		invocation.StatusMessage = "Invocation sent to webhook, awaiting callback"
		logger.Debug("webhook requested async callback mode for invocation %s", invocationID)
	} else {
		invocation.Status = InvocationStatusCompleted
		invocation.StatusMessage = "Invocation delivered successfully"
		invocation.StatusPercent = 100
		logger.Debug("auto-completing invocation %s (no async callback requested)", invocationID)
	}
	persistStatus()
	return nil
}
// generateSignature returns the X-Webhook-Signature header value for payload.
// The value is "sha256=" followed by the lowercase hex HMAC-SHA256 of payload,
// keyed with secret. It is the same value generateHMACSignature produces,
// which is what VerifySignature compares against.
func (w *AddonInvocationWorker) generateSignature(payload []byte, secret string) string {
	signer := hmac.New(sha256.New, []byte(secret))
	_, _ = signer.Write(payload) // hash.Hash.Write never returns an error
	digest := signer.Sum(nil)
	return "sha256=" + hex.EncodeToString(digest)
}
// VerifySignature reports whether signature equals the expected
// "sha256=<lowercase hex HMAC-SHA256>" of payload under secret.
//
// The comparison uses hmac.Equal to avoid leaking timing information about
// the expected value. Uppercase hex digits therefore do not match.
func VerifySignature(payload []byte, signature string, secret string) bool {
	want := []byte(generateHMACSignature(payload, secret))
	got := []byte(signature)
	return hmac.Equal(got, want)
}
// generateHMACSignature returns "sha256=" followed by the lowercase hex
// HMAC-SHA256 of payload, keyed with secret. VerifySignature uses it to
// compute the expected signature.
func generateHMACSignature(payload []byte, secret string) string {
	const prefix = "sha256="
	signer := hmac.New(sha256.New, []byte(secret))
	_, _ = signer.Write(payload) // hash.Hash.Write never returns an error
	return prefix + hex.EncodeToString(signer.Sum(nil))
}
// GlobalAddonInvocationWorker is the global singleton for the invocation worker.
// It is nil until assigned. NOTE(review): presumably assigned during server
// startup (not shown here); confirm, and nil-check before use from code paths
// that may run earlier.
var GlobalAddonInvocationWorker *AddonInvocationWorker