Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 93 additions & 29 deletions server/cmd/api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,35 @@ package api

import (
"context"
"fmt"
"time"

"github.com/onkernel/kernel-images/server/lib/logger"
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/onkernel/kernel-images/server/lib/recorder"
)

// ApiService implements the API endpoints
// It manages a single recording session and provides endpoints for starting, stopping, and downloading it
type ApiService struct {
mainRecorderID string // ID used for the primary recording session
recordManager recorder.RecordManager
factory recorder.FFmpegRecorderFactory
// defaultRecorderID is used whenever the caller doesn't specify an explicit ID.
defaultRecorderID string

recordManager recorder.RecordManager
factory recorder.FFmpegRecorderFactory
}

func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory) *ApiService {
return &ApiService{
recordManager: recordManager,
factory: factory,
mainRecorderID: "main", // use a single recorder for now
func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory) (*ApiService, error) {
switch {
case recordManager == nil:
return nil, fmt.Errorf("recordManager cannot be nil")
case factory == nil:
return nil, fmt.Errorf("factory cannot be nil")
}

return &ApiService{
recordManager: recordManager,
factory: factory,
defaultRecorderID: "default",
}, nil
}

func (s *ApiService) StartRecording(ctx context.Context, req oapi.StartRecordingRequestObject) (oapi.StartRecordingResponseObject, error) {
Expand All @@ -31,26 +40,38 @@ func (s *ApiService) StartRecording(ctx context.Context, req oapi.StartRecording
if req.Body != nil {
params.FrameRate = req.Body.Framerate
params.MaxSizeInMB = req.Body.MaxFileSizeInMB
params.MaxDurationInSeconds = req.Body.MaxDurationInSeconds
}

// Determine recorder ID (use default if none provided)
recorderID := s.defaultRecorderID
if req.Body != nil && req.Body.Id != nil && *req.Body.Id != "" {
recorderID = *req.Body.Id
}

// Create, register, and start a new recorder
rec, err := s.factory(s.mainRecorderID, params)
rec, err := s.factory(recorderID, params)
if err != nil {
log.Error("failed to create recorder", "err", err)
log.Error("failed to create recorder", "err", err, "recorder_id", recorderID)
return oapi.StartRecording500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to create recording"}}, nil
}
if err := s.recordManager.RegisterRecorder(ctx, rec); err != nil {
if rec, exists := s.recordManager.GetRecorder(s.mainRecorderID); exists && rec.IsRecording(ctx) {
log.Error("attempted to start recording while one is already active")
return oapi.StartRecording409JSONResponse{ConflictErrorJSONResponse: oapi.ConflictErrorJSONResponse{Message: "recording already in progress"}}, nil
if rec, exists := s.recordManager.GetRecorder(recorderID); exists {
if rec.IsRecording(ctx) {
log.Error("attempted to start recording while one is already active", "recorder_id", recorderID)
return oapi.StartRecording409JSONResponse{ConflictErrorJSONResponse: oapi.ConflictErrorJSONResponse{Message: "recording already in progress"}}, nil
} else {
log.Error("attempted to restart recording", "recorder_id", recorderID)
return oapi.StartRecording400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "recording already completed"}}, nil
}
}
log.Error("failed to register recorder", "err", err)
log.Error("failed to register recorder", "err", err, "recorder_id", recorderID)
return oapi.StartRecording500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to register recording"}}, nil
}

if err := rec.Start(ctx); err != nil {
log.Error("failed to start recording", "err", err)
// ensure the recorder is deregistered if we fail to start
log.Error("failed to start recording", "err", err, "recorder_id", recorderID)
// ensure the recorder is deregistered
defer s.recordManager.DeregisterRecorder(ctx, rec)
return oapi.StartRecording500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to start recording"}}, nil
}
Expand All @@ -61,10 +82,19 @@ func (s *ApiService) StartRecording(ctx context.Context, req oapi.StartRecording
func (s *ApiService) StopRecording(ctx context.Context, req oapi.StopRecordingRequestObject) (oapi.StopRecordingResponseObject, error) {
log := logger.FromContext(ctx)

rec, exists := s.recordManager.GetRecorder(s.mainRecorderID)
if !exists || !rec.IsRecording(ctx) {
log.Warn("attempted to stop recording when none is active")
// Determine recorder ID
recorderID := s.defaultRecorderID
if req.Body != nil && req.Body.Id != nil && *req.Body.Id != "" {
recorderID = *req.Body.Id
}

rec, exists := s.recordManager.GetRecorder(recorderID)
if !exists {
log.Warn("attempted to stop recording when none is active", "recorder_id", recorderID)
return oapi.StopRecording400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "no active recording to stop"}}, nil
} else if !rec.IsRecording(ctx) {
log.Warn("recording already stopped", "recorder_id", recorderID)
return oapi.StopRecording200Response{}, nil
}

// Check if force stop is requested
Expand All @@ -75,15 +105,15 @@ func (s *ApiService) StopRecording(ctx context.Context, req oapi.StopRecordingRe

var err error
if forceStop {
log.Info("force stopping recording")
log.Info("force stopping recording", "recorder_id", recorderID)
err = rec.ForceStop(ctx)
} else {
log.Info("gracefully stopping recording")
log.Info("gracefully stopping recording", "recorder_id", recorderID)
err = rec.Stop(ctx)
}

if err != nil {
log.Error("error occurred while stopping recording", "err", err, "force", forceStop)
log.Error("error occurred while stopping recording", "err", err, "force", forceStop, "recorder_id", recorderID)
}

return oapi.StopRecording200Response{}, nil
Expand All @@ -96,16 +126,22 @@ const (
func (s *ApiService) DownloadRecording(ctx context.Context, req oapi.DownloadRecordingRequestObject) (oapi.DownloadRecordingResponseObject, error) {
log := logger.FromContext(ctx)

// Determine recorder ID
recorderID := s.defaultRecorderID
if req.Params.Id != nil && *req.Params.Id != "" {
recorderID = *req.Params.Id
}

// Get the recorder to access its output path
rec, exists := s.recordManager.GetRecorder(s.mainRecorderID)
rec, exists := s.recordManager.GetRecorder(recorderID)
if !exists {
log.Error("attempted to download non-existent recording")
log.Error("attempted to download non-existent recording", "recorder_id", recorderID)
return oapi.DownloadRecording404JSONResponse{NotFoundErrorJSONResponse: oapi.NotFoundErrorJSONResponse{Message: "no recording found"}}, nil
}

out, meta, err := rec.Recording(ctx)
if err != nil {
log.Error("failed to get recording", "err", err)
log.Error("failed to get recording", "err", err, "recorder_id", recorderID)
return oapi.DownloadRecording500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "failed to get recording"}}, nil
}

Expand All @@ -118,13 +154,41 @@ func (s *ApiService) DownloadRecording(ctx context.Context, req oapi.DownloadRec
}, nil
}

log.Info("serving recording file for download", "size", meta.Size)
log.Info("serving recording file for download", "size", meta.Size, "recorder_id", recorderID)
return oapi.DownloadRecording200Videomp4Response{
Body: out,
Body: out,
Headers: oapi.DownloadRecording200ResponseHeaders{
XRecordingStartedAt: meta.StartTime.Format(time.RFC3339),
XRecordingFinishedAt: meta.EndTime.Format(time.RFC3339),
},
ContentLength: meta.Size,
}, nil
}

// ListRecorders returns a list of all registered recorders and whether each one is currently recording.
func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequestObject) (oapi.ListRecordersResponseObject, error) {
infos := []oapi.RecorderInfo{}

timeOrNil := func(t time.Time) *time.Time {
if t.IsZero() {
return nil
}
return &t
}

recs := s.recordManager.ListActiveRecorders(ctx)
for _, r := range recs {
m := r.Metadata()
infos = append(infos, oapi.RecorderInfo{
Id: r.ID(),
IsRecording: r.IsRecording(ctx),
StartedAt: timeOrNil(m.StartTime),
FinishedAt: timeOrNil(m.EndTime),
})
}
return oapi.ListRecorders200JSONResponse(infos), nil
}

func (s *ApiService) Shutdown(ctx context.Context) error {
return s.recordManager.StopAll(ctx)
}
76 changes: 60 additions & 16 deletions server/cmd/api/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package api
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"testing"

oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/onkernel/kernel-images/server/lib/recorder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -17,38 +19,70 @@ func TestApiService_StartRecording(t *testing.T) {

t.Run("success", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)

resp, err := svc.StartRecording(ctx, oapi.StartRecordingRequestObject{})
require.NoError(t, err)
require.IsType(t, oapi.StartRecording201Response{}, resp)

rec, exists := mgr.GetRecorder("main")
rec, exists := mgr.GetRecorder("default")
require.True(t, exists, "recorder was not registered")
require.True(t, rec.IsRecording(ctx), "recorder should be recording after Start")
})

t.Run("already recording", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)

// First start should succeed
_, err := svc.StartRecording(ctx, oapi.StartRecordingRequestObject{})
_, err = svc.StartRecording(ctx, oapi.StartRecordingRequestObject{})
require.NoError(t, err)

// Second start should return conflict
resp, err := svc.StartRecording(ctx, oapi.StartRecordingRequestObject{})
require.NoError(t, err)
require.IsType(t, oapi.StartRecording409JSONResponse{}, resp)
})

t.Run("custom ids don't collide", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)

for i := 0; i < 5; i++ {
customID := fmt.Sprintf("rec-%d", i)
resp, err := svc.StartRecording(ctx, oapi.StartRecordingRequestObject{Body: &oapi.StartRecordingJSONRequestBody{Id: &customID}})
require.NoError(t, err)
require.IsType(t, oapi.StartRecording201Response{}, resp)

rec, exists := mgr.GetRecorder(customID)
assert.True(t, exists)
assert.True(t, rec.IsRecording(ctx))
}

out := mgr.ListActiveRecorders(ctx)
assert.Equal(t, 5, len(out))
for _, rec := range out {
assert.NotEqual(t, "default", rec.ID())
}

err = mgr.StopAll(ctx)
require.NoError(t, err)

out = mgr.ListActiveRecorders(ctx)
assert.Equal(t, 5, len(out))
})
}

func TestApiService_StopRecording(t *testing.T) {
ctx := context.Background()

t.Run("no active recording", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)

resp, err := svc.StopRecording(ctx, oapi.StopRecordingRequestObject{})
require.NoError(t, err)
Expand All @@ -57,10 +91,11 @@ func TestApiService_StopRecording(t *testing.T) {

t.Run("graceful stop", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
rec := &mockRecorder{id: "main", isRecordingFlag: true}
rec := &mockRecorder{id: "default", isRecordingFlag: true}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)
resp, err := svc.StopRecording(ctx, oapi.StopRecordingRequestObject{})
require.NoError(t, err)
require.IsType(t, oapi.StopRecording200Response{}, resp)
Expand All @@ -69,12 +104,13 @@ func TestApiService_StopRecording(t *testing.T) {

t.Run("force stop", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
rec := &mockRecorder{id: "main", isRecordingFlag: true}
rec := &mockRecorder{id: "default", isRecordingFlag: true}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

force := true
req := oapi.StopRecordingRequestObject{Body: &oapi.StopRecordingJSONRequestBody{ForceStop: &force}}
svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)
resp, err := svc.StopRecording(ctx, req)
require.NoError(t, err)
require.IsType(t, oapi.StopRecording200Response{}, resp)
Expand All @@ -87,7 +123,8 @@ func TestApiService_DownloadRecording(t *testing.T) {

t.Run("not found", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
require.NoError(t, err)
require.IsType(t, oapi.DownloadRecording404JSONResponse{}, resp)
Expand All @@ -103,10 +140,11 @@ func TestApiService_DownloadRecording(t *testing.T) {

t.Run("still recording", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
rec := &mockRecorder{id: "main", isRecordingFlag: true, recordingData: randomBytes(minRecordingSizeInBytes - 1)}
rec := &mockRecorder{id: "default", isRecordingFlag: true, recordingData: randomBytes(minRecordingSizeInBytes - 1)}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)
// will return a 202 when the recording is too small
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
require.NoError(t, err)
Expand All @@ -132,10 +170,11 @@ func TestApiService_DownloadRecording(t *testing.T) {
t.Run("success", func(t *testing.T) {
mgr := recorder.NewFFmpegManager()
data := []byte("dummy video data")
rec := &mockRecorder{id: "main", recordingData: data}
rec := &mockRecorder{id: "default", recordingData: data}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)
resp, err := svc.DownloadRecording(ctx, oapi.DownloadRecordingRequestObject{})
require.NoError(t, err)
r, ok := resp.(oapi.DownloadRecording200Videomp4Response)
Expand All @@ -151,10 +190,11 @@ func TestApiService_DownloadRecording(t *testing.T) {
func TestApiService_Shutdown(t *testing.T) {
ctx := context.Background()
mgr := recorder.NewFFmpegManager()
rec := &mockRecorder{id: "main", isRecordingFlag: true}
rec := &mockRecorder{id: "default", isRecordingFlag: true}
require.NoError(t, mgr.RegisterRecorder(ctx, rec), "failed to register recorder")

svc := New(mgr, newMockFactory())
svc, err := New(mgr, newMockFactory())
require.NoError(t, err)

require.NoError(t, svc.Shutdown(ctx))
require.True(t, rec.stopCalled, "Shutdown should have stopped active recorder")
Expand Down Expand Up @@ -219,6 +259,10 @@ func (m *mockRecorder) Recording(ctx context.Context) (io.ReadCloser, *recorder.
return reader, meta, nil
}

func (m *mockRecorder) Metadata() *recorder.RecordingMetadata {
return &recorder.RecordingMetadata{}
}

func newMockFactory() recorder.FFmpegRecorderFactory {
return func(id string, _ recorder.FFmpegRecordingParams) (recorder.Recorder, error) {
rec := &mockRecorder{id: id}
Expand Down
Loading
Loading