diff --git a/byoc/job_orchestrator.go b/byoc/job_orchestrator.go index 159f9d83d3..210474992e 100644 --- a/byoc/job_orchestrator.go +++ b/byoc/job_orchestrator.go @@ -48,9 +48,11 @@ func (bs *BYOCOrchestratorServer) RegisterCapability() http.Handler { return } defer r.Body.Close() - extCapSettings := string(body) remoteAddr := getRemoteAddr(r) + // The request body contains the capability settings JSON with the token field + extCapSettings := string(body) + cap, err := orch.RegisterExternalCapability(extCapSettings) w.Header().Set("Content-Type", "application/json") @@ -63,7 +65,7 @@ func (bs *BYOCOrchestratorServer) RegisterCapability() http.Handler { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) - clog.Infof(context.TODO(), "registered capability remoteAddr=%v capability=%v url=%v price=%v", remoteAddr, cap.Name, cap.Url, big.NewRat(cap.PricePerUnit, cap.PriceScaling)) + clog.Infof(context.TODO(), "registered capability remoteAddr=%v capability=%v url=%v price=%v auth_token=%v", remoteAddr, cap.Name, cap.Url, big.NewRat(cap.PricePerUnit, cap.PriceScaling), cap.AuthToken != "") }) } @@ -271,6 +273,13 @@ func (bso *BYOCOrchestratorServer) processJob(ctx context.Context, w http.Respon req.Header.Add("Content-Length", r.Header.Get("Content-Length")) req.Header.Add("Content-Type", r.Header.Get("Content-Type")) + // Add Authorization header if auth token is set for this capability + if extCap, ok := bso.node.ExternalCapabilities.Capabilities[orchJob.Req.Capability]; ok { + if extCap.AuthToken != "" { + req.Header.Add("Authorization", "Bearer "+extCap.AuthToken) + } + } + start := time.Now() resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) if err != nil { @@ -288,6 +297,16 @@ func (bso *BYOCOrchestratorServer) processJob(ctx context.Context, w http.Respon return } + // Check for 401 Unauthorized - remove capability so worker can re-register with correct token + if resp.StatusCode == http.StatusUnauthorized { + clog.Errorf(ctx, "received 401 Unauthorized from worker, removing capability %v", orchJob.Req.Capability) + bso.orch.RemoveExternalCapability(orchJob.Req.Capability) + bso.chargeForCompute(start, orchJob.JobPrice, orchJob.Sender, orchJob.Req.Capability) + w.Header().Set(jobPaymentBalanceHdr, bso.getPaymentBalance(orchJob.Sender, orchJob.Req.Capability).FloatString(0)) + http.Error(w, "job not able to be processed, removing capability err=worker auth token failed", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", resp.Header.Get("Content-Type")) w.Header().Set("X-Metadata", resp.Header.Get("X-Metadata")) diff --git a/byoc/job_orchestrator_test.go b/byoc/job_orchestrator_test.go index 9dbe8fd1fd..55253d9669 100644 --- a/byoc/job_orchestrator_test.go +++ b/byoc/job_orchestrator_test.go @@ -784,6 +784,89 @@ func TestProcessJob_MethodNotAllowed(t *testing.T) { assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) } +func TestProcessJob_WorkerAuthFailed(t *testing.T) { + // Mock worker that returns 401 Unauthorized + workerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "worker auth token failed"}`)) + })) + defer workerServer.Close() + + mockVerifySig := func(addr ethcommon.Address, msg string, sig []byte) bool { + return true + } + mockGetUrlForCapability := func(capability string) string { + return workerServer.URL + } + mockJobPriceInfo := func(addr ethcommon.Address, cap string) (*net.PriceInfo, error) { + return &net.PriceInfo{ + PricePerUnit: 0, + PixelsPerUnit: 1, + }, nil + } + + var removeCapCalled bool + mockRemoveExternalCapability := func(string) error { + removeCapCalled = true + return nil + } + + mockOrch := newMockJobOrchestrator() + mockOrch.verifySignature = mockVerifySig + mockOrch.getUrlForCapability = mockGetUrlForCapability + mockOrch.jobPriceInfo = mockJobPriceInfo + mockOrch.unregisterExternalCapability = mockRemoveExternalCapability + + bso := &BYOCOrchestratorServer{ + node: mockOrch.node, + orch: mockOrch, + } + + // Prepare job request + jobParams := JobParameters{EnableVideoIngress: true, EnableVideoEgress: true, EnableDataOutput: true} + paramsStr := marshalToString(t, jobParams) + jobReq := &JobRequest{ + ID: "test-job", + Capability: "test-capability", + Parameters: paramsStr, + Timeout: 70, + Request: "{}", + } + + orchJob := &orchJob{Req: jobReq, Params: &jobParams} + gatewayJob := &gatewayJob{Job: orchJob, node: mockOrch.node} + + // Setup signing - sign the request using the gateway job pattern + mockOrch.node.OrchestratorPool = newStubOrchestratorPool(mockOrch.node, []string{workerServer.URL}) + gatewayJob.sign() + mockOrch.node.OrchestratorPool = nil + + // Make POST request to /process/request/ endpoint + req := httptest.NewRequest(http.MethodPost, "/process/request/test-capability", bytes.NewReader([]byte("{}"))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(jobRequestHdr, gatewayJob.SignedJobReq) + + w := httptest.NewRecorder() + handler := bso.ProcessJob() + handler.ServeHTTP(w, req) + + resp := w.Result() + + // Verify response is 500 Internal Server Error (as per job_orchestrator.go line 306) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Contains(t, w.Body.String(), "worker auth token failed") + + // Verify capability was removed due to 401 + assert.True(t, removeCapCalled, "RemoveExternalCapability should have been called for 401") + + // Note: FreeExternalCapabilityCapacity is NOT called for 401 responses + // because the function returns early (job_orchestrator.go lines 300-308) + // before reaching the defer statement at line 315 + + workerServer.CloseClientConnections() +} + func TestProcessPayment(t *testing.T) { ctx := context.Background() diff --git a/byoc/stream_orchestrator.go b/byoc/stream_orchestrator.go index 14a8b97f73..359e269a61 100644 --- a/byoc/stream_orchestrator.go +++ b/byoc/stream_orchestrator.go @@ -149,12 +149,16 @@ func (bso *BYOCOrchestratorServer) StartStream() http.Handler { return } - req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(reqBodyBytes)) + req, err := bso.createWorkerReq(ctx, workerRoute, orchJob.Req.Capability, orchJob.Req.ID, bytes.NewBuffer(reqBodyBytes)) + if err != nil { + clog.Errorf(ctx, "failed to create worker request err=%v", err) + respondWithError(w, "Failed to create worker request", http.StatusInternalServerError) + failedToStartStream = true + return + } // set the headers req.Header.Add("Content-Length", r.Header.Get("Content-Length")) req.Header.Add("Content-Type", r.Header.Get("Content-Type")) - // use for routing to worker if reverse proxy in front of workers - req.Header.Add("X-Stream-Id", orchJob.Req.ID) start := time.Now() resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) @@ -165,23 +169,12 @@ func (bso *BYOCOrchestratorServer) StartStream() http.Handler { return } - respBody, err := io.ReadAll(resp.Body) - if err != nil { - clog.Errorf(ctx, "Error reading response body: %v", err) - respondWithError(w, "Error reading response body", http.StatusInternalServerError) - failedToStartStream = true - return - } - defer resp.Body.Close() - - //error response from worker but assume can retry and pass along error response and status code - if resp.StatusCode > 399 { - clog.Errorf(ctx, "error processing stream start request statusCode=%d", resp.StatusCode) - + statusCode, respBody := bso.processWorkerResp(ctx, orchJob.Req.Capability, resp) + if statusCode > 399 { bso.chargeForCompute(start, orchJob.JobPrice, orchJob.Sender, orchJob.Req.Capability) w.Header().Set(jobPaymentBalanceHdr, bso.getPaymentBalance(orchJob.Sender, orchJob.Req.Capability).FloatString(0)) //return error response from the worker - w.WriteHeader(resp.StatusCode) + w.WriteHeader(statusCode) w.Write(respBody) failedToStartStream = true return @@ -281,8 +274,11 @@ func (bso *BYOCOrchestratorServer) monitorOrchStream(job *orchJob) { // if not, send stop to worker and exit monitoring stream, exists := bso.node.ExternalCapabilities.GetStream(streamID) if !exists { - req, err := http.NewRequestWithContext(ctx, "POST", job.Req.CapabilityUrl+"/stream/stop", nil) - // set the headers + req, err := bso.createWorkerReq(ctx, job.Req.CapabilityUrl+"/stream/stop", job.Req.Capability, streamID, nil) + if err != nil { + clog.Errorf(ctx, "Error creating request to worker %v: %v", job.Req.CapabilityUrl, err) + return + } resp, err := sendReqWithTimeout(req, time.Duration(job.Req.Timeout)*time.Second) if err != nil { clog.Errorf(ctx, "Error sending request to worker %v: %v", job.Req.CapabilityUrl, err) @@ -331,32 +327,21 @@ func (bso *BYOCOrchestratorServer) StopStream() http.Handler { r.Body.Close() workerRoute := orchJob.Req.CapabilityUrl + "/stream/stop" - req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(body)) + req, err := bso.createWorkerReq(ctx, workerRoute, orchJob.Req.Capability, jobDetails.StreamId, bytes.NewBuffer(body)) if err != nil { clog.Errorf(ctx, "failed to create /stream/stop request to worker err=%v", err) http.Error(w, err.Error(), http.StatusBadRequest) return } - // use for routing to worker if reverse proxy in front of workers - req.Header.Add("X-Stream-Id", jobDetails.StreamId) resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) if err != nil { clog.Errorf(ctx, "Error sending request to worker %v: %v", workerRoute, err) } var respBody []byte - respStatusCode := http.StatusOK // default to 200, if not nill will be overwritten + respStatusCode := http.StatusOK // default to 200, if not nil will be overwritten if resp != nil { - respBody, err = io.ReadAll(resp.Body) - if err != nil { - clog.Errorf(ctx, "Error reading response body: %v", err) - } - defer resp.Body.Close() - - respStatusCode = resp.StatusCode - if resp.StatusCode > 399 { - clog.Errorf(ctx, "error processing stream stop request statusCode=%d", resp.StatusCode) - } + respStatusCode, respBody = bso.processWorkerResp(ctx, orchJob.Req.Capability, resp) } // Stop the stream and free capacity @@ -393,15 +378,13 @@ func (bso *BYOCOrchestratorServer) UpdateStream() http.Handler { r.Body.Close() workerRoute := orchJob.Req.CapabilityUrl + "/stream/params" - req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, bytes.NewBuffer(body)) + req, err := bso.createWorkerReq(ctx, workerRoute, orchJob.Req.Capability, jobDetails.StreamId, bytes.NewBuffer(body)) if err != nil { clog.Errorf(ctx, "failed to create /stream/params request to worker err=%v", err) http.Error(w, err.Error(), http.StatusBadRequest) return } req.Header.Add("Content-Type", "application/json") - // use for routing to worker if reverse proxy in front of workers - req.Header.Add("X-Stream-Id", jobDetails.StreamId) resp, err := sendReqWithTimeout(req, time.Duration(orchJob.Req.Timeout)*time.Second) if err != nil { @@ -410,21 +393,61 @@ func (bso *BYOCOrchestratorServer) UpdateStream() http.Handler { return } - respBody, err := io.ReadAll(resp.Body) - if err != nil { - clog.Errorf(ctx, "Error reading response body: %v", err) - respondWithError(w, "Error reading response body", http.StatusInternalServerError) - return + statusCode, respBody := bso.processWorkerResp(ctx, orchJob.Req.Capability, resp) + + w.WriteHeader(statusCode) + w.Write(respBody) + }) +} + +// createWorkerReq creates an HTTP request to send to the worker. +// handles setting stream id and auth headers for worker +func (bso *BYOCOrchestratorServer) createWorkerReq(ctx context.Context, workerRoute, capability, streamId string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, "POST", workerRoute, body) + if err != nil { + return nil, err + } + + // Add stream ID header for routing to worker if reverse proxy in front of workers + if streamId != "" { + req.Header.Add("X-Stream-Id", streamId) + } + + // Add Authorization header if auth token is set for this capability + if extCap, ok := bso.node.ExternalCapabilities.Capabilities[capability]; ok { + if extCap.AuthToken != "" { + req.Header.Add("Authorization", "Bearer "+extCap.AuthToken) } - defer resp.Body.Close() + } + + return req, nil +} + +// processWorkerResp processes the worker response and returns the statusCode and respBody. +// It handles 401 Unauthorized responses by removing the capability. +func (bso *BYOCOrchestratorServer) processWorkerResp(ctx context.Context, capability string, resp *http.Response) (int, []byte) { + statusCode := resp.StatusCode + respBody, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + clog.Errorf(ctx, "Error reading response body: %v", err) + return http.StatusInternalServerError, []byte("Error reading response body") + } - if resp.StatusCode > 399 { - clog.Errorf(ctx, "error processing stream update request statusCode=%d", resp.StatusCode) + if statusCode > 399 { + clog.Errorf(ctx, "error processing stream request statusCode=%d", statusCode) + + // Check for 401 Unauthorized - remove capability so worker can re-register with correct token + // return 500 error to the gateway. Gateway will move on to another Orchestrator if available. + if statusCode == http.StatusUnauthorized { + clog.Errorf(ctx, "received 401 Unauthorized from worker, removing capability %v", capability) + bso.orch.RemoveExternalCapability(capability) + statusCode = http.StatusInternalServerError + respBody = []byte("Orchestrator worker failure") } + } - w.WriteHeader(resp.StatusCode) - w.Write(respBody) - }) + return statusCode, respBody } func (bso *BYOCOrchestratorServer) ProcessStreamPayment() http.Handler { diff --git a/byoc/stream_test.go b/byoc/stream_test.go index 46ff463a6d..569d4500ed 100644 --- a/byoc/stream_test.go +++ b/byoc/stream_test.go @@ -1669,8 +1669,6 @@ func TestGetStreamRequestParams(t *testing.T) { }) } -// TestStartStreamWorkerErrorResponse tests the error response handling from worker -// when worker returns status code > 399 (lines 154-182 in stream_orchestrator.go) func TestStartStreamWorkerErrorResponse(t *testing.T) { // Mock worker that returns 400 Bad Request statusCodeReturned := http.StatusBadRequest @@ -1816,6 +1814,44 @@ func TestStartStreamWorkerErrorResponse(t *testing.T) { }) }) + t.Run("WorkerReturns401_Unauthorized", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + statusCodeReturned = http.StatusUnauthorized + freeCapacityCalled = false + + var removeCapCalled bool + mockRemoveCap := func(string) error { + removeCapCalled = true + return nil + } + mockOrch.unregisterExternalCapability = mockRemoveCap + + req := httptest.NewRequest(http.MethodPost, "/process/stream/start", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // Set up job request header + req.Header.Set(jobRequestHdr, gatewayJob.SignedJobReq) + + w := httptest.NewRecorder() + handler := bso.StartStream() + handler.ServeHTTP(w, req) + + // Verify 500 error received after catch/change at Orchestrator + assert.Equal(t, http.StatusInternalServerError, w.Code) + + // Verify capability was removed + assert.True(t, removeCapCalled, "RemoveExternalCapability should have been called for 401") + + // Verify freeCapacity was called + assert.True(t, freeCapacityCalled, "FreeExternalCapabilityCapacity should have been called") + + server.CloseClientConnections() + + // no stream created + assert.Zero(t, len(mockOrch.node.ExternalCapabilities.Streams)) + }) + }) + t.Run("WorkerReturns500_FatalError", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { statusCodeReturned = http.StatusInternalServerError diff --git a/core/external_capabilities.go b/core/external_capabilities.go index 4a6c165570..27ba79e45f 100644 --- a/core/external_capabilities.go +++ b/core/external_capabilities.go @@ -22,6 +22,7 @@ type ExternalCapability struct { PricePerUnit int64 `json:"price_per_unit"` PriceScaling int64 `json:"price_scaling"` PriceCurrency string `json:"currency"` + AuthToken string `json:"token"` price *AutoConvertedPrice @@ -222,6 +223,7 @@ func (extCaps *ExternalCapabilities) RegisterCapability(extCapability string) (* cap.Url = extCap.Url cap.Capacity = extCap.Capacity cap.price = extCap.price + cap.AuthToken = extCap.AuthToken } extCaps.Capabilities[extCap.Name] = &extCap