Skip to content

Commit 39dbf3c

Browse files
committed
Add any served models in a column in the jobs table
Enables the user to leave and return to the model chat eval without having launch a new job, instead use the existing job for either pre-train or post-train. Signed-off-by: Brent Salisbury <[email protected]>
1 parent 94ddd85 commit 39dbf3c

File tree

4 files changed

+163
-33
lines changed

4 files changed

+163
-33
lines changed

api-server/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ go.work.sum
2121
# env file
2222
.env
2323

24+
# binary
25+
api-server
26+
2427
# app specific
2528
logs/
2629
jobs.json

api-server/handlers.go

Lines changed: 96 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"bytes"
5+
"database/sql"
56
"encoding/json"
67
"fmt"
78
"github.com/gorilla/mux"
@@ -243,12 +244,21 @@ func (srv *ILabServer) getVllmStatusHandler(w http.ResponseWriter, r *http.Reque
243244
return
244245
}
245246

246-
srv.jobIDsMutex.RLock()
247-
jobID, ok := srv.servedModelJobIDs[modelName]
248-
srv.jobIDsMutex.RUnlock()
247+
// Directly query the DB for the job associated with this model
248+
var jobID string
249+
err = srv.db.QueryRow(`
250+
SELECT job_id
251+
FROM jobs
252+
WHERE served_model_name = ? AND status = 'running'
253+
LIMIT 1
254+
`, modelName).Scan(&jobID)
249255

250-
if !ok {
251-
srv.log.Infof("WTF jobid not found for model '%s'", modelName)
256+
if err == sql.ErrNoRows {
257+
srv.log.Infof("No running job found for model '%s'", modelName)
258+
_ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"})
259+
return
260+
} else if err != nil {
261+
srv.log.Errorf("Error querying job for model '%s': %v", modelName, err)
252262
_ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"})
253263
return
254264
}
@@ -629,6 +639,26 @@ func (srv *ILabServer) runVllmContainerHandler(
629639
gpuIndex int, hostVolume, containerVolume string,
630640
w http.ResponseWriter,
631641
) {
642+
// Check if a job is already running for the requested model
643+
existingJob, err := srv.getRunningJobByModel(servedModelName)
644+
if err != nil {
645+
srv.log.Errorf("Error checking existing jobs for model '%s': %v", servedModelName, err)
646+
http.Error(w, "Internal server error", http.StatusInternalServerError)
647+
return
648+
}
649+
if existingJob != nil {
650+
srv.log.Infof("A job is already running for model '%s' with job_id: %s", servedModelName, existingJob.JobID)
651+
w.Header().Set("Content-Type", "application/json")
652+
_ = json.NewEncoder(w).Encode(map[string]string{
653+
"status": "already_running",
654+
"job_id": existingJob.JobID,
655+
"message": fmt.Sprintf("Model '%s' is already being served.", servedModelName),
656+
})
657+
return
658+
}
659+
660+
srv.log.Infof("No existing job found for model '%s'. Starting a new job.", servedModelName)
661+
632662
cmdArgs := []string{
633663
"run", "--rm",
634664
fmt.Sprintf("--device=nvidia.com/gpu=%d", gpuIndex),
@@ -681,13 +711,14 @@ func (srv *ILabServer) runVllmContainerHandler(
681711

682712
// Create a Job record and store it in the DB
683713
newJob := &Job{
684-
JobID: jobID,
685-
Cmd: "podman",
686-
Args: cmdArgs,
687-
Status: "running",
688-
PID: cmd.Process.Pid,
689-
LogFile: logFilePath,
690-
StartTime: time.Now(),
714+
JobID: jobID,
715+
Cmd: "podman",
716+
Args: cmdArgs,
717+
Status: "running",
718+
PID: cmd.Process.Pid,
719+
LogFile: logFilePath,
720+
StartTime: time.Now(),
721+
ServedModelName: servedModelName,
691722
}
692723
if err := srv.createJob(newJob); err != nil {
693724
srv.log.Errorf("Failed to create job in DB for %s: %v", jobID, err)
@@ -859,6 +890,59 @@ func (srv *ILabServer) serveModelHandler(modelPath, port string, w http.Response
859890
_ = json.NewEncoder(w).Encode(map[string]string{"status": "model process started", "job_id": jobID})
860891
}
861892

893+
// getRunningJobByModel retrieves a running job for the specified served_model_name.
894+
// Returns nil if no such job exists.
895+
func (srv *ILabServer) getRunningJobByModel(servedModelName string) (*Job, error) {
896+
var job Job
897+
var argsJSON string
898+
var startTimeStr, endTimeStr sql.NullString
899+
900+
row := srv.db.QueryRow(`
901+
SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name
902+
FROM jobs
903+
WHERE served_model_name = ? AND status = 'running'
904+
LIMIT 1
905+
`, servedModelName)
906+
907+
err := row.Scan(
908+
&job.JobID,
909+
&job.Cmd,
910+
&argsJSON,
911+
&job.Status,
912+
&job.PID,
913+
&job.LogFile,
914+
&startTimeStr,
915+
&endTimeStr,
916+
&job.Branch,
917+
&job.ServedModelName,
918+
)
919+
if err == sql.ErrNoRows {
920+
return nil, nil
921+
} else if err != nil {
922+
return nil, err
923+
}
924+
925+
if err := json.Unmarshal([]byte(argsJSON), &job.Args); err != nil {
926+
srv.log.Errorf("Failed to unmarshal Args for job '%s': %v", job.JobID, err)
927+
return nil, fmt.Errorf("failed to unmarshal Args for job '%s': %v", job.JobID, err)
928+
}
929+
930+
if startTimeStr.Valid {
931+
t, err := time.Parse(time.RFC3339, startTimeStr.String)
932+
if err == nil {
933+
job.StartTime = t
934+
}
935+
}
936+
if endTimeStr.Valid && endTimeStr.String != "" {
937+
t, err := time.Parse(time.RFC3339, endTimeStr.String)
938+
if err == nil {
939+
job.EndTime = &t
940+
}
941+
}
942+
943+
return &job, nil
944+
}
945+
862946
// listServedModelJobIDsHandler is a debug endpoint to list current model to jobID mappings.
863947
func (srv *ILabServer) listServedModelJobIDsHandler(w http.ResponseWriter, r *http.Request) {
864948
srv.jobIDsMutex.RLock()

api-server/jobs.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ func (srv *ILabServer) initDB() {
3333
log_file TEXT,
3434
start_time TEXT,
3535
end_time TEXT,
36-
branch TEXT
36+
branch TEXT,
37+
served_model_name TEXT
3738
);
3839
`
3940
_, err = srv.db.Exec(createTableSQL)
@@ -58,8 +59,8 @@ func (srv *ILabServer) createJob(job *Job) error {
5859
endTimeStr = &s
5960
}
6061
_, err = srv.db.Exec(`
61-
INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch)
62-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
62+
INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name)
63+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
6364
`,
6465
job.JobID,
6566
job.Cmd,
@@ -70,6 +71,7 @@ func (srv *ILabServer) createJob(job *Job) error {
7071
job.StartTime.Format(time.RFC3339),
7172
endTimeStr,
7273
job.Branch,
74+
job.ServedModelName,
7375
)
7476
if err != nil {
7577
return fmt.Errorf("failed to insert job: %v", err)
@@ -79,7 +81,7 @@ func (srv *ILabServer) createJob(job *Job) error {
7981

8082
// getJob fetches a single job by job_id.
8183
func (srv *ILabServer) getJob(jobID string) (*Job, error) {
82-
row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch FROM jobs WHERE job_id = ?", jobID)
84+
row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name FROM jobs WHERE job_id = ?", jobID)
8385

8486
var j Job
8587
var argsJSON string
@@ -95,6 +97,7 @@ func (srv *ILabServer) getJob(jobID string) (*Job, error) {
9597
&startTimeStr,
9698
&endTimeStr,
9799
&j.Branch,
100+
&j.ServedModelName,
98101
)
99102
if err == sql.ErrNoRows {
100103
return nil, nil // not found
@@ -133,7 +136,7 @@ func (srv *ILabServer) updateJob(job *Job) error {
133136
}
134137
_, err = srv.db.Exec(`
135138
UPDATE jobs
136-
SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ?
139+
SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ?, served_model_name = ?
137140
WHERE job_id = ?
138141
`,
139142
job.Cmd,
@@ -144,6 +147,7 @@ func (srv *ILabServer) updateJob(job *Job) error {
144147
job.StartTime.Format(time.RFC3339),
145148
endTimeStr,
146149
job.Branch,
150+
job.ServedModelName,
147151
job.JobID,
148152
)
149153
if err != nil {

api-server/main.go

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,16 @@ type Data struct {
3939

4040
// Job represents a background job, including train/generate/pipeline/vllm-run jobs.
4141
type Job struct {
42-
JobID string `json:"job_id"`
43-
Cmd string `json:"cmd"`
44-
Args []string `json:"args"`
45-
Status string `json:"status"` // "running", "finished", "failed"
46-
PID int `json:"pid"`
47-
LogFile string `json:"log_file"`
48-
StartTime time.Time `json:"start_time"`
49-
EndTime *time.Time `json:"end_time,omitempty"`
50-
Branch string `json:"branch"`
42+
JobID string `json:"job_id"`
43+
Cmd string `json:"cmd"`
44+
Args []string `json:"args"`
45+
Status string `json:"status"` // "running", "finished", "failed"
46+
PID int `json:"pid"`
47+
LogFile string `json:"log_file"`
48+
StartTime time.Time `json:"start_time"`
49+
EndTime *time.Time `json:"end_time,omitempty"`
50+
Branch string `json:"branch"`
51+
ServedModelName string `json:"served_model_name"`
5152

5253
// Lock is not serialized; it protects updates to the Job in memory.
5354
Lock sync.Mutex `json:"-"`
@@ -94,7 +95,7 @@ type ILabServer struct {
9495
useVllm bool
9596
pipelineType string
9697
debugEnabled bool
97-
homeDir string // New field added
98+
homeDir string
9899

99100
// Logger
100101
logger *zap.Logger
@@ -119,12 +120,7 @@ type ILabServer struct {
119120
modelCache ModelCache
120121
}
121122

122-
// -----------------------------------------------------------------------------
123-
// main(), flags and Cobra
124-
// -----------------------------------------------------------------------------
125-
126123
func main() {
127-
// We create an instance of ILabServer to hold all state and methods.
128124
srv := &ILabServer{
129125
baseModel: "instructlab/granite-7b-lab",
130126
servedModelJobIDs: make(map[string]string),
@@ -135,7 +131,6 @@ func main() {
135131
Use: "ilab-server",
136132
Short: "ILab Server Application",
137133
Run: func(cmd *cobra.Command, args []string) {
138-
// Now that flags are set, run the server method on the struct.
139134
srv.runServer(cmd, args)
140135
},
141136
}
@@ -248,6 +243,8 @@ func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) {
248243
// Initialize the model cache
249244
srv.initializeModelCache()
250245

246+
srv.reconstructServedModelJobIDs()
247+
251248
// Create the logs directory if it doesn't exist
252249
err = os.MkdirAll("logs", os.ModePerm)
253250
if err != nil {
@@ -348,6 +345,48 @@ func (srv *ILabServer) refreshModelCache() {
348345
srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models))
349346
}
350347

348+
// reconstructServedModelJobIDs rebuilds the servedModelJobIDs map by querying the database
349+
func (srv *ILabServer) reconstructServedModelJobIDs() {
350+
srv.log.Info("Reconstructing servedModelJobIDs from the database...")
351+
352+
rows, err := srv.db.Query(`
353+
SELECT job_id, served_model_name
354+
FROM jobs
355+
WHERE cmd = 'podman' AND status = 'running'
356+
`)
357+
if err != nil {
358+
srv.log.Errorf("Error querying running vLLM jobs: %v", err)
359+
return
360+
}
361+
defer rows.Close()
362+
363+
for rows.Next() {
364+
var jobID, servedModelName string
365+
if err := rows.Scan(&jobID, &servedModelName); err != nil {
366+
srv.log.Errorf("Error scanning row: %v", err)
367+
continue
368+
}
369+
370+
// Validate servedModelName
371+
if servedModelName != "pre-train" && servedModelName != "post-train" {
372+
srv.log.Warnf("Invalid served_model_name '%s' for job_id '%s'", servedModelName, jobID)
373+
continue
374+
}
375+
376+
// Update the servedModelJobIDs map
377+
srv.jobIDsMutex.Lock()
378+
srv.servedModelJobIDs[servedModelName] = jobID
379+
srv.jobIDsMutex.Unlock()
380+
srv.log.Infof("Mapped model '%s' to job_id '%s'", servedModelName, jobID)
381+
}
382+
383+
if err := rows.Err(); err != nil {
384+
srv.log.Errorf("Error iterating over rows: %v", err)
385+
}
386+
387+
srv.log.Info("Reconstruction of servedModelJobIDs completed.")
388+
}
389+
351390
// -----------------------------------------------------------------------------
352391
// Start Generate Data Job
353392
// -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)