diff --git a/cmd/launcher/web/api/api.go b/cmd/launcher/web/api/api.go index 75d828f95..8245d1ab1 100644 --- a/cmd/launcher/web/api/api.go +++ b/cmd/launcher/web/api/api.go @@ -70,8 +70,19 @@ func (a *apiLauncher) UserMessage(webURL string, printer func(v ...any)) { // SetupSubrouters adds the API router to the parent router. func (a *apiLauncher) SetupSubrouters(router *mux.Router, config *launcher.Config) error { + // Create the ADK REST API config + adkrestConfig := &adkrest.Config{ + AgentLoader: config.AgentLoader, + SessionService: config.SessionService, + ArtifactService: config.ArtifactService, + MemoryService: config.MemoryService, + SSEWriteTimeout: a.config.sseWriteTimeout, + } // Create the ADK REST API handler - apiHandler := adkrest.NewHandler(config, a.config.sseWriteTimeout) + apiHandler, err := adkrest.NewHandler(adkrestConfig) + if err != nil { + return fmt.Errorf("failed to create ADK REST API handler: %v", err) + } // Wrap it with CORS middleware corsHandler := corsWithArgs(a.config.frontendAddress)(apiHandler) diff --git a/examples/rest/main.go b/examples/rest/main.go index 1e9a6bfbe..756966a1e 100644 --- a/examples/rest/main.go +++ b/examples/rest/main.go @@ -26,7 +26,8 @@ import ( "google.golang.org/adk/agent" "google.golang.org/adk/agent/llmagent" - "google.golang.org/adk/cmd/launcher" + "google.golang.org/adk/artifact" + "google.golang.org/adk/memory" "google.golang.org/adk/model/gemini" "google.golang.org/adk/server/adkrest" "google.golang.org/adk/session" @@ -60,13 +61,19 @@ func main() { } // Configure the ADK REST API - config := &launcher.Config{ - AgentLoader: agent.NewSingleLoader(a), - SessionService: session.InMemoryService(), + config := &adkrest.Config{ + SessionService: session.InMemoryService(), + ArtifactService: artifact.InMemoryService(), + AgentLoader: agent.NewSingleLoader(a), + MemoryService: memory.InMemoryService(), + SSEWriteTimeout: 120 * time.Second, } // Create the REST API handler - this returns a standard http.Handler - apiHandler := adkrest.NewHandler(config, 120*time.Second) + apiHandler, err := adkrest.NewHandler(config) + if err != nil { + log.Fatalf("Failed to create ADK REST API handler: %v", err) + } // Create a standard net/http ServeMux mux := http.NewServeMux() diff --git a/server/adkrest/config.go b/server/adkrest/config.go new file mode 100644 index 000000000..83b24e093 --- /dev/null +++ b/server/adkrest/config.go @@ -0,0 +1,55 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adkrest + +import ( + "errors" + "time" + + "google.golang.org/adk/agent" + "google.golang.org/adk/artifact" + "google.golang.org/adk/memory" + "google.golang.org/adk/runner" + "google.golang.org/adk/session" + "google.golang.org/adk/telemetry" +) + +// Config defines the services and loaders required by the adkrest package. +type Config struct { + SessionService session.Service + ArtifactService artifact.Service + AgentLoader agent.Loader + MemoryService memory.Service + SSEWriteTimeout time.Duration + PluginConfig runner.PluginConfig + TelemetryOptions []telemetry.Option +} + +// validate validates the config +func (c *Config) validate() error { + if c.SessionService == nil { + return errors.New("session service is required") + } + if c.ArtifactService == nil { + return errors.New("artifact service is required") + } + if c.AgentLoader == nil { + return errors.New("agent loader is required") + } + if c.MemoryService == nil { + return errors.New("memory service is required") + } + return nil +} diff --git a/server/adkrest/handler.go b/server/adkrest/handler.go index ba2b88f90..c5e5a7475 100644 --- a/server/adkrest/handler.go +++ b/server/adkrest/handler.go @@ -16,11 +16,9 @@ package adkrest import ( "net/http" - "time" "github.com/gorilla/mux" - "google.golang.org/adk/cmd/launcher" "google.golang.org/adk/server/adkrest/controllers" "google.golang.org/adk/server/adkrest/internal/routers" "google.golang.org/adk/server/adkrest/internal/services" @@ -28,7 +26,11 @@ import ( ) // NewHandler creates and returns an http.Handler for the ADK REST API. -func NewHandler(config *launcher.Config, sseWriteTimeout time.Duration) http.Handler { +func NewHandler(config *Config) (http.Handler, error) { + if err := config.validate(); err != nil { + return nil, err + } + debugTelemetry := services.NewDebugTelemetry() config.TelemetryOptions = append(config.TelemetryOptions, telemetry.WithSpanProcessors(debugTelemetry.SpanProcessor())) config.TelemetryOptions = append(config.TelemetryOptions, telemetry.WithLogRecordProcessors(debugTelemetry.LogProcessor())) @@ -38,13 +40,13 @@ func NewHandler(config *launcher.Config, sseWriteTimeout time.Duration) http.Han // where the ADK REST API will be served. setupRouter(router, routers.NewSessionsAPIRouter(controllers.NewSessionsAPIController(config.SessionService)), - routers.NewRuntimeAPIRouter(controllers.NewRuntimeAPIController(config.SessionService, config.MemoryService, config.AgentLoader, config.ArtifactService, sseWriteTimeout, config.PluginConfig)), + routers.NewRuntimeAPIRouter(controllers.NewRuntimeAPIController(config.SessionService, config.MemoryService, config.AgentLoader, config.ArtifactService, config.SSEWriteTimeout, config.PluginConfig)), routers.NewAppsAPIRouter(controllers.NewAppsAPIController(config.AgentLoader)), routers.NewDebugAPIRouter(controllers.NewDebugAPIController(config.SessionService, config.AgentLoader, debugTelemetry)), routers.NewArtifactsAPIRouter(controllers.NewArtifactsAPIController(config.ArtifactService)), &routers.EvalAPIRouter{}, ) - return router + return router, nil } func setupRouter(router *mux.Router, subrouters ...routers.Router) *mux.Router {