-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathinference-service.go
More file actions
254 lines (222 loc) · 6.49 KB
/
inference-service.go
File metadata and controls
254 lines (222 loc) · 6.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"os/exec"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
// InferenceCompletionRequest is the JSON payload describing one inference
// completion request.
type InferenceCompletionRequest struct {
// LlamaCliArgs holds the llama CLI options for this request (the LlamaCliArgs
// type is declared elsewhere in the package).
LlamaCliArgs LlamaCliArgs `json:"llamaCliArgs"`
// PromptText is the prompt text supplied by the caller.
PromptText string `json:"promptText"`
// PromptType classifies the prompt. NOTE(review): the accepted values are not
// visible in this file — confirm against the code that consumes it.
PromptType string `json:"promptType"`
// RequestID is an optional identifier; it is omitted from the JSON encoding
// when empty. Presumably echoed back in the matching response/progress
// messages, which carry the same field — confirm.
RequestID string `json:"requestId,omitempty"`
}
// InferenceCompletionResponse is the JSON result returned for an inference
// completion request.
type InferenceCompletionResponse struct {
// RequestID identifies the request this response belongs to; omitted from the
// JSON encoding when empty.
RequestID string `json:"requestId,omitempty"`
// Success reports whether the request succeeded.
Success bool `json:"success"`
// Result holds the result text (presumably the completion output — confirm
// with the code that populates it).
Result string `json:"result"`
// Error describes a failure; omitted from the JSON encoding when empty.
Error string `json:"error,omitempty"`
// ProcessingTime is the time spent handling the request. NOTE(review): the
// unit is not established in this file (milliseconds?) — confirm and document.
ProcessingTime int64 `json:"processingTime"`
}
// InferenceCompletionProgress is a JSON progress update for an inference
// completion request.
type InferenceCompletionProgress struct {
// RequestID identifies the request being reported on; omitted from the JSON
// encoding when empty.
RequestID string `json:"requestId,omitempty"`
// Status is a short status label. NOTE(review): the set of possible values is
// not defined in this file.
Status string `json:"status"`
// Message is free-form progress text.
Message string `json:"message"`
// Progress is the completion percentage; the 0-100 range is a convention and
// is not enforced by this type.
Progress int `json:"progress"` // 0-100
}
// isOperationCanceled reports, without blocking, whether the app's current
// operation context has been cancelled.
//
// A nil operationCtx is logged as an initialization error and reported as
// "not cancelled". Every call that observes cancellation logs it at Info level.
func (app *App) isOperationCanceled() bool {
	opCtx := app.operationCtx
	if opCtx == nil {
		// Failing open keeps the caller running; returning true here would
		// instead treat a missing context as cancelled.
		app.log.Error("operationCtx is nil - this indicates improper initialization")
		return false
	}

	select {
	case <-opCtx.Done():
		// Fall through to report the cancellation below.
	default:
		// Done has not been closed: the operation is still live.
		return false
	}

	app.log.Info("Operation was cancelled by user")
	return true
}
// saveQuestionResponse persists one completion through SaveQuestionResponse.
//
// llamaCliArgs is JSON-encoded and its first and last characters are stripped
// (for an object encoding, the surrounding braces) before being stored next to
// the raw completion output and the original prompt text. A marshalling failure
// is returned wrapped; otherwise the result of SaveQuestionResponse is returned.
func (app *App) saveQuestionResponse(llamaCliArgs LlamaCliArgs, completionOutput []byte, originalPromptText string) error {
	encodedArgs, marshalErr := json.Marshal(llamaCliArgs)
	if marshalErr != nil {
		return fmt.Errorf("failed to convert arguments to JSON: %w", marshalErr)
	}

	argsBody := app.removeJSONBrackets(string(encodedArgs))
	outputText := string(completionOutput)
	return SaveQuestionResponse(app.appArgs, outputText, argsBody, originalPromptText)
}
// removeJSONBrackets strips the first and last byte of jsonStr. For the JSON
// encoding of an object or array these are the enclosing braces/brackets, so
// the result is just the inner members ("" for "{}" or "[]").
//
// The characters are not inspected: callers are expected to pass json.Marshal
// output of an object or array. Inputs shorter than two bytes are returned
// unchanged.
func (app *App) removeJSONBrackets(jsonStr string) string {
	// A two-byte input such as "{}" must be reduced to "" too; the old `> 2`
	// guard left empty objects/arrays with their brackets intact.
	if len(jsonStr) < 2 {
		return jsonStr
	}
	return jsonStr[1 : len(jsonStr)-1]
}
// generateUniqueFileName returns "<baseName>_<YYYYMMDD>_<HHMMSS>_<NNNN>.txt",
// where the timestamp is the current local time and NNNN is a zero-padded
// pseudo-random number in [0, 9999]. Uniqueness is probabilistic only: two
// calls within the same second can produce the same name.
func generateUniqueFileName(baseName string) string {
	const stampLayout = "20060102_150405"
	stamp := time.Now().Format(stampLayout)
	suffix := rand.Intn(10000)
	return fmt.Sprintf("%s_%s_%04d.txt", baseName, stamp, suffix)
}
// llamaCliMutex serializes llama CLI invocations made through
// GenerateSingleCompletionWithCancel and GenerateTokenCount. It is held for the
// whole lifetime of the child process, including after a cancellation, until
// the process has been reaped.
var llamaCliMutex sync.Mutex

// GenerateSingleCompletionWithCancel runs the llama CLI at appArgs.LLamaCliPath
// with args and returns what the process wrote to standard output.
//
// Only one llama CLI process runs at a time (see llamaCliMutex). If ctx is
// cancelled before the process exits, the process is killed and waited for,
// and (nil, ctx.Err()) is returned. A nil ctx is treated as
// context.Background(). On a non-zero exit the error is an *exec.ExitError
// whose Stderr holds the process's standard error.
func GenerateSingleCompletionWithCancel(ctx context.Context, appArgs DefaultAppArgs, args []string) ([]byte, error) {
	// Debug trace of the argument list to stdout (kept from the original code).
	fmt.Println(args)
	return runLlamaCliCommand(ctx, appArgs.LLamaCliPath, args)
}

// GenerateTokenCount runs the token-count CLI at appArgs.LLamaTokenCountCliPath
// with args and returns what the process wrote to standard output. It shares
// llamaCliMutex and the cancellation/error behavior of
// GenerateSingleCompletionWithCancel.
func GenerateTokenCount(ctx context.Context, appArgs DefaultAppArgs, args []string) ([]byte, error) {
	return runLlamaCliCommand(ctx, appArgs.LLamaTokenCountCliPath, args)
}

// runLlamaCliCommand runs path with args while holding llamaCliMutex and
// returns the captured standard output. The process is started synchronously,
// exactly one goroutine calls Wait, and on cancellation the process is killed
// and reaped before the mutex is released.
func runLlamaCliCommand(ctx context.Context, path string, args []string) ([]byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	llamaCliMutex.Lock()
	defer llamaCliMutex.Unlock()

	// Don't start a process for a request that was cancelled while queued on
	// the mutex.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	cmd := exec.Command(path, args...)
	applyLlamaCliProcAttr(cmd)

	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Start(); err != nil {
		return nil, err
	}

	// Wait is called from this goroutine only. The channel is buffered so the
	// send never blocks.
	waitErr := make(chan error, 1)
	go func() {
		waitErr <- cmd.Wait()
	}()

	select {
	case err := <-waitErr:
		if exitErr, ok := err.(*exec.ExitError); ok {
			// Like exec.Cmd.Output, attach the captured stderr to the exit error.
			exitErr.Stderr = []byte(stderr.String())
		}
		return []byte(stdout.String()), err
	case <-ctx.Done():
		// Start has already returned, so cmd.Process is set and safe to use.
		// A Signal error just means the process already exited; Wait still
		// reaps it below.
		_ = cmd.Process.Signal(os.Kill)
		// Keep holding llamaCliMutex until the process is actually gone so the
		// next invocation can never overlap with it.
		// NOTE(review): if the CLI ever spawns children that inherit
		// stdout/stderr, Wait may block until they exit as well.
		<-waitErr
		return nil, ctx.Err()
	}
}

// applyLlamaCliProcAttr sets process attributes on cmd. On Windows the child
// runs with its window hidden and in a new process group. On other platforms
// an empty SysProcAttr is set (no process-group settings are applied).
//
// NOTE(review): HideWindow and syscall.CREATE_NEW_PROCESS_GROUP exist only in
// the Windows syscall package, so despite the runtime.GOOS check this file only
// compiles for GOOS=windows. Supporting other targets would need build-tagged
// files — confirm the intended build targets.
func applyLlamaCliProcAttr(cmd *exec.Cmd) {
	if runtime.GOOS != "windows" {
		cmd.SysProcAttr = &syscall.SysProcAttr{}
		return
	}
	cmd.SysProcAttr = &syscall.SysProcAttr{
		HideWindow:    true,
		CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
	}
}
// totalTokensPattern matches "total number of tokens: N" case-insensitively,
// capturing N (digits with optional comma separators). It is compiled once at
// package initialization instead of on every ExtractTokenCount call.
var totalTokensPattern = regexp.MustCompile(`(?i)total\s+number\s+of\s+tokens:\s*([\d,]+)`)

// ExtractTokenCount parses a token count from output text containing a
// "total number of tokens: N" line (matched case-insensitively). Comma
// separators in N are ignored. If the line appears more than once, the last
// occurrence wins.
//
// It returns an error when no such line exists or when N does not parse as an
// int (for example, a separator-only value such as ",").
func ExtractTokenCount(output string) (int, error) {
	matches := totalTokensPattern.FindAllStringSubmatch(output, -1)
	if len(matches) == 0 {
		return 0, fmt.Errorf("token count not found in output")
	}

	// Take the last match in case the line appears multiple times.
	raw := strings.ReplaceAll(matches[len(matches)-1][1], ",", "")
	n, err := strconv.Atoi(raw)
	if err != nil {
		return 0, fmt.Errorf("invalid token count %q: %w", raw, err)
	}
	return n, nil
}