diff --git a/server/cmd/api/api/api.go b/server/cmd/api/api/api.go
index 36aa3acb..13381327 100644
--- a/server/cmd/api/api/api.go
+++ b/server/cmd/api/api/api.go
@@ -2,12 +2,19 @@ package api
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
+	"net/http"
+	"net/url"
 	"os"
+	"strconv"
+	"strings"
 	"sync"
 	"time"
 
+	"github.com/gorilla/websocket"
 	"github.com/onkernel/kernel-images/server/lib/logger"
 	oapi "github.com/onkernel/kernel-images/server/lib/oapi"
 	"github.com/onkernel/kernel-images/server/lib/recorder"
@@ -28,8 +35,200 @@ type ApiService struct {
 	procs map[string]*processHandle
 }
 
+// Compile-time check that ApiService implements the generated strict server interface.
 var _ oapi.StrictServerInterface = (*ApiService)(nil)
 
+// wsAuthQuery is the query string carrying the credentials expected by the
+// browser's WebSocket endpoint.
+// NOTE(review): these credentials are hard-coded and appear in logged URLs;
+// consider moving them to configuration.
+const wsAuthQuery = "?password=admin&username=kernel"
+
+// isWebSocketAvailable reports whether a WebSocket connection to wsURL can be
+// established. It probes the TCP endpoint first and then performs a full
+// WebSocket handshake, each with a 200ms timeout.
+func isWebSocketAvailable(wsURL string) bool {
+	u, err := url.Parse(wsURL)
+	if err != nil {
+		return false
+	}
+
+	// Fall back to the scheme's default port when the URL carries none.
+	// u.Port() is used rather than searching for ":" so that IPv6 literals
+	// such as "[::1]" are handled correctly.
+	port := u.Port()
+	if port == "" {
+		switch u.Scheme {
+		case "ws":
+			port = "80"
+		case "wss":
+			port = "443"
+		default:
+			// Unknown scheme: there is no sensible port to probe.
+			return false
+		}
+	}
+
+	const probeTimeout = 200 * time.Millisecond
+
+	// Cheap TCP probe first so an unreachable host fails fast.
+	tcpConn, err := net.DialTimeout("tcp", net.JoinHostPort(u.Hostname(), port), probeTimeout)
+	if err != nil {
+		return false
+	}
+	_ = tcpConn.Close() // probe only; a close error carries no information
+
+	dialer := websocket.Dialer{HandshakeTimeout: probeTimeout}
+	wsConn, _, err := dialer.Dial(wsURL, nil)
+	if err != nil {
+		return false
+	}
+	_ = wsConn.Close() // probe only; a close error carries no information
+
+	return true
+}
+
+// stripPort returns hostport without its port. IPv6 literals keep their
+// brackets so the result can be embedded in a URL. Values without a port are
+// returned unchanged.
+func stripPort(hostport string) string {
+	host, _, err := net.SplitHostPort(hostport)
+	if err != nil {
+		// No port present (or not parseable as host:port): leave as is.
+		return hostport
+	}
+	if strings.Contains(host, ":") {
+		return "[" + host + "]"
+	}
+	return host
+}
+
+// getWebSocketURL derives the WebSocket URL of the browser service from an
+// incoming HTTP request. A nil request (as in tests) and requests addressed to
+// localhost yield the local development URL; a request without a Host yields
+// the loopback URL. Otherwise the URL is built from the request's scheme, its
+// host with the port removed, and the path prefix preceding /screen/resolution.
+func getWebSocketURL(r *http.Request) string {
+	const localDevURL = "ws://localhost:8080/ws" + wsAuthQuery
+
+	if r == nil {
+		return localDevURL
+	}
+
+	log := logger.FromContext(r.Context())
+
+	if r.Host == "" {
+		// Nothing to derive a public URL from; use the in-container endpoint.
+		internalURL := "ws://127.0.0.1:8080/ws" + wsAuthQuery
+		log.Warn("empty host in request, using internal WebSocket URL", "url", internalURL)
+		return internalURL
+	}
+
+	// The port is dropped because production WebSocket URLs must not carry one.
+	host := stripPort(r.Host)
+
+	if strings.Contains(host, "localhost") {
+		// In development the WebSocket service listens on a fixed local port.
+		log.Info("localhost detected, using development WebSocket URL", "url", localDevURL)
+		return localDevURL
+	}
+
+	// r.Proto holds the protocol version (e.g. "HTTP/1.1") and says nothing
+	// about TLS, so only the connection state and the proxy header are used.
+	scheme := "ws"
+	if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") {
+		scheme = "wss"
+	}
+
+	// Strip trailing slashes and the endpoint's own path to get the prefix.
+	basePath := strings.TrimSuffix(strings.TrimRight(r.URL.Path, "/"), "/screen/resolution")
+
+	wsURL := fmt.Sprintf("%s://%s%s/ws%s", scheme, host, basePath, wsAuthQuery)
+	log.Info("using host-based WebSocket URL", "url", wsURL)
+	return wsURL
+}
+
+// requiredIntParam reads the mandatory integer query parameter name from query.
+// The returned error text is suitable for a 400 response body.
+func requiredIntParam(query url.Values, name string) (int, error) {
+	raw := query.Get(name)
+	if raw == "" {
+		return 0, fmt.Errorf("missing required query parameter: %s", name)
+	}
+	value, err := strconv.Atoi(raw)
+	if err != nil {
+		return 0, fmt.Errorf("invalid %s parameter: must be an integer", name)
+	}
+	return value, nil
+}
+
+// SetScreenResolutionHandler is the HTTP entry point for GET /screen/resolution.
+// It parses the width, height and optional rate query parameters, derives the
+// WebSocket URL from the request, delegates to SetScreenResolution and writes
+// the result as JSON with the matching status code.
+func (s *ApiService) SetScreenResolutionHandler(w http.ResponseWriter, r *http.Request) {
+	query := r.URL.Query()
+
+	width, err := requiredIntParam(query, "width")
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	height, err := requiredIntParam(query, "height")
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	var rate *int
+	if rateStr := query.Get("rate"); rateStr != "" {
+		rateVal, err := strconv.Atoi(rateStr)
+		if err != nil {
+			http.Error(w, "invalid rate parameter: must be an integer", http.StatusBadRequest)
+			return
+		}
+		rate = &rateVal
+	}
+
+	// Derived only once the parameters parsed, so malformed requests do no extra work.
+	wsURL := getWebSocketURL(r)
+
+	resp, err := s.SetScreenResolution(r.Context(), SetScreenResolutionRequestObject{
+		Width:  width,
+		Height: height,
+		Rate:   rate,
+		WSURL:  wsURL,
+	})
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	// Map the response type to its status code; the body is encoded once below.
+	var status int
+	switch resp.(type) {
+	case SetScreenResolution200JSONResponse:
+		status = http.StatusOK
+	case SetScreenResolution400JSONResponse:
+		status = http.StatusBadRequest
+	case SetScreenResolution409JSONResponse:
+		status = http.StatusConflict
+	case SetScreenResolution500JSONResponse:
+		status = http.StatusInternalServerError
+	default:
+		http.Error(w, "unexpected response type", http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(status)
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
+		// The status line is already sent; all that is left is to record the failure.
+		logger.FromContext(r.Context()).Warn("failed to write screen resolution response", "err", err)
+	}
+}
+
 func New(recordManager recorder.RecordManager, factory recorder.FFmpegRecorderFactory) (*ApiService, error) {
 	switch {
 	case recordManager == nil:
diff --git a/server/cmd/api/api/computer.go b/server/cmd/api/api/computer.go
index 7d8b8c04..6b739d6e 100644
--- a/server/cmd/api/api/computer.go
+++ b/server/cmd/api/api/computer.go
@@ -2,9 +2,15 @@ package api
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
+	"net/url"
+	"os/exec"
 	"strconv"
+	"strings"
+	"time"
 
+	"github.com/gorilla/websocket"
 	"github.com/onkernel/kernel-images/server/lib/logger"
 	oapi "github.com/onkernel/kernel-images/server/lib/oapi"
 )
@@ -54,6 +60,215 @@ func (s *ApiService) MoveMouse(ctx context.Context, request oapi.MoveMouseReques
 	return oapi.MoveMouse200Response{}, nil
 }
 
+// The types below mirror what oapi-codegen would generate for the
+// /screen/resolution endpoint; they are defined by hand here.
+
+// SetScreenResolutionParams represents the query parameters of the endpoint.
+type SetScreenResolutionParams struct {
+	Width  int
+	Height int
+	Rate   *int
+}
+
+// SetScreenResolutionFunc matches the signature of ApiService.SetScreenResolution;
+// it is intended for use in tests.
+type SetScreenResolutionFunc func(ctx context.Context, req SetScreenResolutionRequestObject) (SetScreenResolutionResponseObject, error)
+
+// SetScreenResolutionRequestObject is the input of SetScreenResolution.
+type SetScreenResolutionRequestObject struct {
+	Width  int    // required query parameter
+	Height int    // required query parameter
+	Rate   *int   // optional query parameter
+	WSURL  string // WebSocket URL, calculated in the handler
+}
+
+// SetScreenResolution200JSONResponse is returned when the command was sent.
+type SetScreenResolution200JSONResponse struct {
+	Ok bool `json:"ok"`
+}
+
+// SetScreenResolution400JSONResponse is returned for out-of-range parameters.
+type SetScreenResolution400JSONResponse struct {
+	Message string `json:"message"`
+}
+
+// SetScreenResolution409JSONResponse is returned while a recording is running.
+type SetScreenResolution409JSONResponse struct {
+	Message string `json:"message"`
+}
+
+// SetScreenResolution500JSONResponse is returned when the WebSocket command fails.
+type SetScreenResolution500JSONResponse struct {
+	Message string `json:"message"`
+}
+
+// SetScreenResolutionResponseObject is implemented by every response type that
+// SetScreenResolution can return.
+type SetScreenResolutionResponseObject interface {
+	SetScreenResolutionResponse()
+}
+
+func (SetScreenResolution200JSONResponse) SetScreenResolutionResponse() {}
+func (SetScreenResolution400JSONResponse) SetScreenResolutionResponse() {}
+func (SetScreenResolution409JSONResponse) SetScreenResolutionResponse() {}
+func (SetScreenResolution500JSONResponse) SetScreenResolutionResponse() {}
+
+// Accepted ranges for the screen resolution parameters; they match the limits
+// declared for the endpoint in openapi.yaml.
+const (
+	minScreenDimension = 200
+	maxScreenDimension = 8000
+	minScreenRate      = 24
+	maxScreenRate      = 240
+)
+
+// validateScreenResolution checks width, height and the optional rate against
+// the accepted ranges and returns an error describing the first violation.
+func validateScreenResolution(width, height int, rate *int) error {
+	switch {
+	case width < minScreenDimension || width > maxScreenDimension:
+		return fmt.Errorf("width must be between %d and %d, got %d", minScreenDimension, maxScreenDimension, width)
+	case height < minScreenDimension || height > maxScreenDimension:
+		return fmt.Errorf("height must be between %d and %d, got %d", minScreenDimension, maxScreenDimension, height)
+	case rate != nil && (*rate < minScreenRate || *rate > maxScreenRate):
+		return fmt.Errorf("rate must be between %d and %d, got %d", minScreenRate, maxScreenRate, *rate)
+	}
+	return nil
+}
+
+// fallbackWebSocketURLs lists the endpoints tried, in order, when the URL
+// derived from the request cannot be reached.
+var fallbackWebSocketURLs = []string{
+	"ws://127.0.0.1:8080/ws" + wsAuthQuery, // inside the container, straight to the WS service
+	"ws://browser:8080/ws" + wsAuthQuery,   // Docker service name, for container networking
+	"ws://localhost:8080/ws" + wsAuthQuery, // local development
+}
+
+// dialScreenWebSocket connects to primaryURL and, if that fails, to each
+// fallback URL in turn. It returns the first connection established, or the
+// last dial error when every attempt failed.
+func dialScreenWebSocket(ctx context.Context, primaryURL string) (*websocket.Conn, error) {
+	log := logger.FromContext(ctx)
+	dialer := websocket.Dialer{HandshakeTimeout: 3 * time.Second}
+
+	candidates := append([]string{primaryURL}, fallbackWebSocketURLs...)
+	var lastErr error
+	for i, candidate := range candidates {
+		if i > 0 && candidate == primaryURL {
+			continue // already tried as the primary URL
+		}
+
+		log.Info("trying websocket URL", "url", candidate)
+		// DialContext aborts the attempt when the request is cancelled.
+		conn, _, err := dialer.DialContext(ctx, candidate, nil)
+		if err == nil {
+			log.Info("successfully connected to websocket", "url", candidate)
+			return conn, nil
+		}
+		log.Warn("websocket URL failed", "url", candidate, "err", err)
+		lastErr = err
+	}
+	return nil, lastErr
+}
+
+// SetScreenResolution changes the screen resolution by sending a "screen/set"
+// event to the browser's WebSocket endpoint.
+//
+// It returns a 400 response for out-of-range parameters, a 409 response while
+// an ffmpeg process (taken to be a recording) is running, a 500 response when
+// no WebSocket endpoint can be reached or the command cannot be sent, and a
+// 200 response otherwise. The returned error is always nil; failures are
+// reported through the response object.
+func (s *ApiService) SetScreenResolution(ctx context.Context, request SetScreenResolutionRequestObject) (SetScreenResolutionResponseObject, error) {
+	log := logger.FromContext(ctx)
+
+	// The route is registered outside the generated handlers, so the ranges
+	// from the OpenAPI spec have to be enforced here.
+	if err := validateScreenResolution(request.Width, request.Height, request.Rate); err != nil {
+		return SetScreenResolution400JSONResponse{Message: err.Error()}, nil
+	}
+
+	// A running ffmpeg is taken to mean an active recording. pgrep exits
+	// non-zero when nothing matches or when it cannot be run; both cases are
+	// treated as "no recording".
+	if err := exec.CommandContext(ctx, "pgrep", "ffmpeg").Run(); err == nil {
+		return SetScreenResolution409JSONResponse{
+			Message: "detected ongoing replay recording process, close the recording first before switching resolution",
+		}, nil
+	}
+
+	wsURL := request.WSURL
+
+	// URLs that do not point at the local machine must not carry a port.
+	isProduction := !strings.Contains(wsURL, "localhost") && !strings.Contains(wsURL, "127.0.0.1")
+	if isProduction {
+		if parsed, err := url.Parse(wsURL); err == nil && parsed.Port() != "" {
+			parsed.Host = stripPort(parsed.Host)
+			wsURL = parsed.String()
+			log.Info("fixed production WebSocket URL by removing port", "url", wsURL)
+		}
+	}
+
+	conn, err := dialScreenWebSocket(ctx, wsURL)
+	if err != nil {
+		log.Error("all websocket connection attempts failed", "err", err)
+		return SetScreenResolution500JSONResponse{
+			Message: "failed to connect to websocket server after multiple attempts",
+		}, nil
+	}
+
+	// Close the connection cleanly once the command has been sent.
+	defer func() {
+		log.Info("closing websocket connection")
+		closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
+		if err := conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second)); err != nil {
+			log.Warn("failed to send close message", "err", err)
+		}
+		if err := conn.Close(); err != nil {
+			log.Warn("failed to close websocket connection", "err", err)
+		}
+	}()
+
+	message := map[string]interface{}{
+		"event":  "screen/set",
+		"width":  request.Width,
+		"height": request.Height,
+	}
+	if request.Rate != nil {
+		message["rate"] = *request.Rate
+	}
+
+	messageJSON, err := json.Marshal(message)
+	if err != nil {
+		log.Error("failed to marshal JSON message", "err", err)
+		return SetScreenResolution500JSONResponse{
+			Message: "failed to prepare websocket message",
+		}, nil
+	}
+
+	if err := conn.WriteMessage(websocket.TextMessage, messageJSON); err != nil {
+		log.Error("failed to send websocket message", "err", err)
+		return SetScreenResolution500JSONResponse{
+			Message: "failed to send command to websocket server",
+		}, nil
+	}
+
+	// It is not known whether the server acknowledges the command, so a missing
+	// reply within the deadline is logged but not treated as a failure.
+	if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
+		// Without a deadline the read below could block indefinitely.
+		log.Warn("failed to set websocket read deadline, not waiting for a response", "err", err)
+		return SetScreenResolution200JSONResponse{Ok: true}, nil
+	}
+	if _, response, err := conn.ReadMessage(); err != nil {
+		log.Warn("did not receive websocket response, but proceeding", "err", err)
+	} else {
+		log.Info("received websocket response", "response", string(response))
+	}
+
+	return SetScreenResolution200JSONResponse{Ok: true}, nil
+}
+
 func (s *ApiService) ClickMouse(ctx context.Context, request oapi.ClickMouseRequestObject) (oapi.ClickMouseResponseObject, error) {
 	log := logger.FromContext(ctx)
 
diff --git a/server/cmd/api/api/computer_test.go b/server/cmd/api/api/computer_test.go
new file mode 100644
index 00000000..98dfca67
--- /dev/null
+++ b/server/cmd/api/api/computer_test.go
@@ -0,0 +1,129 @@
+package api
+
+import (
+	"crypto/tls"
+	"net/http"
+	"net/url"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// intPtr returns a pointer to i, for optional integer fields in test tables.
+func intPtr(i int) *int {
+	return &i
+}
+
+// TestScreenResolutionParameterValidation checks the range validation applied
+// to width, height and rate, including the inclusive boundaries.
+func TestScreenResolutionParameterValidation(t *testing.T) {
+	testCases := []struct {
+		name     string
+		width    int
+		height   int
+		rate     *int
+		errorMsg string // empty when the parameters are expected to be valid
+	}{
+		{name: "valid parameters", width: 1920, height: 1080, rate: intPtr(60)},
+		{name: "valid without rate", width: 1280, height: 720},
+		{name: "lower bounds", width: 200, height: 200, rate: intPtr(24)},
+		{name: "upper bounds", width: 8000, height: 8000, rate: intPtr(240)},
+		{name: "width too small", width: 100, height: 1080, errorMsg: "width must be between 200 and 8000"},
+		{name: "width too large", width: 9000, height: 1080, errorMsg: "width must be between 200 and 8000"},
+		{name: "height too small", width: 1920, height: 100, errorMsg: "height must be between 200 and 8000"},
+		{name: "height too large", width: 1920, height: 9000, errorMsg: "height must be between 200 and 8000"},
+		{name: "rate too small", width: 1920, height: 1080, rate: intPtr(10), errorMsg: "rate must be between 24 and 240"},
+		{name: "rate too large", width: 1920, height: 1080, rate: intPtr(300), errorMsg: "rate must be between 24 and 240"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := validateScreenResolution(tc.width, tc.height, tc.rate)
+			if tc.errorMsg == "" {
+				assert.NoError(t, err)
+				return
+			}
+			if assert.Error(t, err) {
+				assert.Contains(t, err.Error(), tc.errorMsg)
+			}
+		})
+	}
+}
+
+// TestGetWebSocketURL checks how the WebSocket URL is derived from a request.
+func TestGetWebSocketURL(t *testing.T) {
+	const auth = "?password=admin&username=kernel"
+
+	testCases := []struct {
+		name        string
+		request     *http.Request
+		expectedURL string
+	}{
+		{
+			name:        "nil request",
+			request:     nil,
+			expectedURL: "ws://localhost:8080/ws" + auth,
+		},
+		{
+			name:        "standard http request",
+			request:     &http.Request{Host: "example.com", URL: &url.URL{Path: "/screen/resolution"}},
+			expectedURL: "ws://example.com/ws" + auth,
+		},
+		{
+			name: "https request",
+			request: &http.Request{
+				Host: "example.com",
+				URL:  &url.URL{Path: "/screen/resolution"},
+				TLS:  &tls.ConnectionState{},
+			},
+			expectedURL: "wss://example.com/ws" + auth,
+		},
+		{
+			name: "https terminated by a proxy",
+			request: &http.Request{
+				Host:   "example.com",
+				URL:    &url.URL{Path: "/screen/resolution"},
+				Header: http.Header{"X-Forwarded-Proto": []string{"https"}},
+			},
+			expectedURL: "wss://example.com/ws" + auth,
+		},
+		{
+			name:        "request with path prefix",
+			request:     &http.Request{Host: "example.com", URL: &url.URL{Path: "/api/v1/screen/resolution"}},
+			expectedURL: "ws://example.com/api/v1/ws" + auth,
+		},
+		{
+			name:        "request with trailing slash",
+			request:     &http.Request{Host: "example.com", URL: &url.URL{Path: "/api/v1/screen/resolution/"}},
+			expectedURL: "ws://example.com/api/v1/ws" + auth,
+		},
+		{
+			name:        "host with port",
+			request:     &http.Request{Host: "example.com:8443", URL: &url.URL{Path: "/screen/resolution"}},
+			expectedURL: "ws://example.com/ws" + auth,
+		},
+		{
+			name:        "IPv6 host with port",
+			request:     &http.Request{Host: "[2001:db8::1]:8443", URL: &url.URL{Path: "/screen/resolution"}},
+			expectedURL: "ws://[2001:db8::1]/ws" + auth,
+		},
+		{
+			name:        "localhost uses the development URL",
+			request:     &http.Request{Host: "localhost:10001", URL: &url.URL{Path: "/screen/resolution"}},
+			expectedURL: "ws://localhost:8080/ws" + auth,
+		},
+		{
+			name:        "empty host falls back to loopback",
+			request:     &http.Request{URL: &url.URL{Path: "/screen/resolution"}},
+			expectedURL: "ws://127.0.0.1:8080/ws" + auth,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Named got rather than url to avoid shadowing the net/url package.
+			got := getWebSocketURL(tc.request)
+			assert.Equal(t, tc.expectedURL, got)
+		})
+	}
+}
diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go
index 57806754..30a86162 100644
--- a/server/cmd/api/main.go
+++ b/server/cmd/api/main.go
@@ -82,6 +82,9 @@ func main() {
 	strictHandler := oapi.NewStrictHandler(apiService, nil)
 	oapi.HandlerFromMux(strictHandler, r)
 
+	// Register the hand-written /screen/resolution handler alongside the generated routes.
+	r.Get("/screen/resolution", apiService.SetScreenResolutionHandler)
+
 	// endpoints to expose the spec
 	r.Get("/spec.yaml", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "application/vnd.oai.openapi")
diff --git a/server/openapi.yaml b/server/openapi.yaml
index 4f4f8e00..5c84bf84 100644
--- a/server/openapi.yaml
+++ b/server/openapi.yaml
@@ -737,6 +737,53 @@ paths:
           $ref: "#/components/responses/NotFoundError"
         "500":
           $ref: "#/components/responses/InternalError"
+
+  /screen/resolution:
+    get:
+      summary: Set screen resolution
+      operationId: setScreenResolution
+      parameters:
+        - name: width
+          in: query
+          required: true
+          schema:
+            type: integer
+            minimum: 200
+            maximum: 8000
+          description: Screen width in pixels
+        - name: height
+          in: query
+          required: true
+          schema:
+            type: integer
+            minimum: 200
+            maximum: 8000
+          description: Screen height in pixels
+        - name: rate
+          in: query
+          required: false
+          schema:
+            type: integer
+            minimum: 24
+            maximum: 240
+          description: Screen refresh rate in Hz
+      responses:
+        "200":
+          description: Screen resolution set successfully
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/OkResponse"
+        "400":
+          $ref: "#/components/responses/BadRequestError"
+        "409":
+          description: Cannot set resolution due to ongoing recording
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/Error"
+        "500":
+          $ref: "#/components/responses/InternalError"
 components:
   schemas:
     StartRecordingRequest: