Skip to content

Commit 97dc0df

Browse files
alexcreasyakram
andauthored
Adds user_token auth mode to llama stack module (opendatahub-io#4596)
* Updates authentication to follow pattern used by model registry. Now supports user_token mode. * fix: properly disable authentication if auth_method=disabled * refactor: optimize TokenClientFactory performance and make API path prefix configurable - Move TokenClientFactory creation from per-request to app-level initialization - Add configurable APIPathPrefix to EnvConfig with default /api/v1 - Add --api-path-prefix command-line flag and API_PATH_PREFIX env var - Replace hardcoded ApiPathPrefix constants with dynamic path generation methods - Update middleware to use app.tokenFactory instead of recreating per request - Update isAPIRoute function to be App method using configurable prefix - Update all tests to work with new configurable approach - Maintain backward compatibility with default /api/v1 prefix This improves performance by eliminating per-request object creation and provides flexibility for different deployment scenarios. --------- Co-authored-by: Akram Ben Aissi <akram.benaissi@gmail.com>
1 parent 24aa71b commit 97dc0df

File tree

14 files changed

+286
-477
lines changed

14 files changed

+286
-477
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
vendor/
2+
pkg/

frontend/packages/llama-stack-modular-ui/bff/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ LOG_LEVEL ?= debug
77
ALLOWED_ORIGINS ?= ""
88
LLAMA_STACK_URL ?= ""
99
MOCK_LS_CLIENT ?= false
10+
AUTH_METHOD ?= "disabled"
1011

1112
.PHONY: all
1213
all: build
@@ -47,7 +48,7 @@ build: fmt vet test ## Builds the project to produce a binary executable.
4748
.PHONY: run
4849
run: fmt vet envtest ## Runs the project.
4950
ENVTEST_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" \
50-
go run ./cmd --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --log-level=$(LOG_LEVEL) --allowed-origins=$(ALLOWED_ORIGINS) --llama-stack-url=$(LLAMA_STACK_URL) --mock-ls-client=$(MOCK_LS_CLIENT)
51+
go run ./cmd --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --log-level=$(LOG_LEVEL) --allowed-origins=$(ALLOWED_ORIGINS) --llama-stack-url=$(LLAMA_STACK_URL) --mock-ls-client=$(MOCK_LS_CLIENT) --auth-method=$(AUTH_METHOD)
5152

5253
##@ Dependencies
5354

frontend/packages/llama-stack-modular-ui/bff/cmd/helpers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ func getEnvAsString(name string, defaultVal string) string {
2424
return defaultVal
2525
}
2626

27+
// TODO: remove nolint comment below when we use this method
28+
//
29+
//nolint:unused
2730
func getEnvAsBool(name string, defaultVal bool) bool {
2831
if value, exists := os.LookupEnv(name); exists {
2932
boolValue, err := strconv.ParseBool(value)

frontend/packages/llama-stack-modular-ui/bff/cmd/main.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,14 @@ func main() {
2323
flag.TextVar(&cfg.LogLevel, "log-level", parseLevel(getEnvAsString("LOG_LEVEL", "DEBUG")), "Sets server log level, possible values: error, warn, info, debug")
2424
flag.Func("allowed-origins", "Sets allowed origins for CORS purposes, accepts a comma separated list of origins or * to allow all, default none", newOriginParser(&cfg.AllowedOrigins, getEnvAsString("ALLOWED_ORIGINS", "")))
2525
flag.BoolVar(&cfg.MockLSClient, "mock-ls-client", false, "Use mock Llama Stack client")
26+
flag.StringVar(&cfg.AuthMethod, "auth-method", "disabled", "Authentication method (disabled or user_token)")
27+
flag.StringVar(&cfg.AuthTokenHeader, "auth-token-header", getEnvAsString("AUTH_TOKEN_HEADER", config.DefaultAuthTokenHeader), "Header used to extract the token (e.g., Authorization)")
28+
flag.StringVar(&cfg.AuthTokenPrefix, "auth-token-prefix", getEnvAsString("AUTH_TOKEN_PREFIX", config.DefaultAuthTokenPrefix), "Prefix used in the token header (e.g., 'Bearer ')")
29+
flag.StringVar(&cfg.APIPathPrefix, "api-path-prefix", getEnvAsString("API_PATH_PREFIX", "/api/v1"), "API path prefix for BFF endpoints (e.g., /api/v1)")
2630

2731
// Llama Stack configuration
2832
flag.StringVar(&cfg.LlamaStackURL, "llama-stack-url", getEnvAsString("LLAMA_STACK_URL", ""), "Llama Stack server URL for proxying requests")
2933

30-
// OAuth configuration
31-
flag.BoolVar(&cfg.OAuthEnabled, "oauth-enabled", getEnvAsBool("OAUTH_ENABLED", false), "Enable OAuth authentication")
32-
flag.StringVar(&cfg.OAuthClientID, "oauth-client-id", getEnvAsString("OAUTH_CLIENT_ID", ""), "OAuth client ID")
33-
flag.StringVar(&cfg.OAuthClientSecret, "oauth-client-secret", getEnvAsString("OAUTH_CLIENT_SECRET", ""), "OAuth client secret")
34-
flag.StringVar(&cfg.OAuthRedirectURI, "oauth-redirect-uri", getEnvAsString("OAUTH_REDIRECT_URI", ""), "OAuth redirect URI")
35-
flag.StringVar(&cfg.OAuthServerURL, "oauth-server-url", getEnvAsString("OAUTH_SERVER_URL", ""), "OAuth server URL")
36-
flag.StringVar(&cfg.OpenShiftApiServerUrl, "openshift-api-server-url", getEnvAsString("OPENSHIFT_API_SERVER_URL", "https://kubernetes.default.svc.cluster.local"), "OpenShift API server URL for token validation")
37-
flag.StringVar(&cfg.OAuthUserInfoEndpoint, "oauth-user-info-endpoint", getEnvAsString("OAUTH_USER_INFO_ENDPOINT", ""), "OAuth user info endpoint URL for token validation (optional, defaults to OpenShift API server + /apis/user.openshift.io/v1/users/~)")
38-
3934
flag.Parse()
4035

4136
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{

frontend/packages/llama-stack-modular-ui/bff/internal/api/app.go

Lines changed: 33 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,47 +13,53 @@ import (
1313
"github.com/julienschmidt/httprouter"
1414
"github.com/opendatahub-io/llama-stack-modular-ui/internal/config"
1515
helper "github.com/opendatahub-io/llama-stack-modular-ui/internal/helpers"
16+
"github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations"
1617
)
1718

1819
const (
1920
Version = "1.0.0"
2021

21-
ApiPathPrefix = "/api/v1"
2222
HealthCheckPath = "/healthcheck"
23-
24-
OauthCallbackPath = ApiPathPrefix + "/auth/callback"
25-
OauthStatePath = ApiPathPrefix + "/auth/state"
26-
27-
ConfigPath = ApiPathPrefix + "/config"
28-
29-
ModelListPath = ApiPathPrefix + "/models"
30-
VectorDBListPath = ApiPathPrefix + "/vector-dbs"
31-
32-
// making it simpler than /tool-runtime/rag-tool/insert
33-
UploadPath = ApiPathPrefix + "/upload"
34-
QueryPath = ApiPathPrefix + "/query"
3523
)
3624

3725
// isAPIRoute checks if the given path is an API route
38-
func isAPIRoute(path string) bool {
26+
func (app *App) isAPIRoute(path string) bool {
3927
return path == HealthCheckPath ||
4028
path == OpenAPIPath ||
4129
path == OpenAPIJSONPath ||
4230
path == OpenAPIYAMLPath ||
4331
path == SwaggerUIPath ||
44-
// Match exactly “/api/v1” or any sub-path under it
45-
path == ApiPathPrefix ||
46-
strings.HasPrefix(path, ApiPathPrefix+"/") ||
32+
// Match exactly the configured API path prefix or any sub-path under it
33+
path == app.config.APIPathPrefix ||
34+
strings.HasPrefix(path, app.config.APIPathPrefix+"/") ||
4735
// Similarly for the llama-stack prefix
4836
path == "/llama-stack" ||
4937
strings.HasPrefix(path, "/llama-stack/")
5038
}
5139

40+
// Path generation methods for configurable API paths
41+
func (app *App) getModelListPath() string {
42+
return app.config.APIPathPrefix + "/models"
43+
}
44+
45+
func (app *App) getVectorDBListPath() string {
46+
return app.config.APIPathPrefix + "/vector-dbs"
47+
}
48+
49+
func (app *App) getUploadPath() string {
50+
return app.config.APIPathPrefix + "/upload"
51+
}
52+
53+
func (app *App) getQueryPath() string {
54+
return app.config.APIPathPrefix + "/query"
55+
}
56+
5257
type App struct {
5358
config config.EnvConfig
5459
logger *slog.Logger
5560
repositories *repositories.Repositories
5661
openAPI *OpenAPIHandler
62+
tokenFactory *integrations.TokenClientFactory
5763
}
5864

5965
func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
@@ -62,25 +68,6 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
6268

6369
logger.Info("Initializing app with config", slog.Any("config", cfg))
6470

65-
// Validate OAuth configuration
66-
if cfg.OAuthEnabled {
67-
if cfg.OAuthServerURL == "" {
68-
return nil, fmt.Errorf("OAUTH_SERVER_URL is required when OAuth is enabled")
69-
}
70-
if cfg.OAuthClientID == "" {
71-
return nil, fmt.Errorf("OAUTH_CLIENT_ID is required when OAuth is enabled")
72-
}
73-
if cfg.OAuthClientSecret == "" {
74-
return nil, fmt.Errorf("OAUTH_CLIENT_SECRET is required when OAuth is enabled")
75-
}
76-
if cfg.OAuthRedirectURI == "" {
77-
return nil, fmt.Errorf("OAUTH_REDIRECT_URI is required when OAuth is enabled")
78-
}
79-
logger.Info("OAuth configuration validated",
80-
slog.String("oauth_server_url", cfg.OAuthServerURL),
81-
slog.String("openshift_api_server_url", cfg.OpenShiftApiServerUrl))
82-
}
83-
8471
if cfg.MockLSClient {
8572
lsClient, err = mocks.NewLlamastackClientMock()
8673
} else {
@@ -102,6 +89,7 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
10289
logger: logger,
10390
repositories: repositories.NewRepositories(lsClient),
10491
openAPI: openAPIHandler,
92+
tokenFactory: integrations.NewTokenClientFactory(logger, cfg),
10593
}
10694
return app, nil
10795
}
@@ -113,43 +101,19 @@ func (app *App) Routes() http.Handler {
113101
apiRouter.NotFound = http.HandlerFunc(app.notFoundResponse)
114102
apiRouter.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
115103

116-
// OAuth routes
117-
if app.config.OAuthEnabled {
118-
apiRouter.POST(OauthCallbackPath, app.HandleOAuthCallback)
119-
apiRouter.GET(OauthStatePath, app.HandleOAuthState)
120-
}
121-
122-
// Config endpoint (not authenticated)
123-
apiRouter.GET(ConfigPath, app.HandleConfig)
124-
125-
apiRouter.GET(ModelListPath, app.RequireAuthRoute(app.AttachRESTClient(app.GetAllModelsHandler)))
126-
apiRouter.GET(VectorDBListPath, app.RequireAuthRoute(app.AttachRESTClient(app.GetAllVectorDBsHandler)))
104+
apiRouter.GET(app.getModelListPath(), app.RequireAccessToService(app.AttachRESTClient(app.GetAllModelsHandler)))
105+
apiRouter.GET(app.getVectorDBListPath(), app.RequireAccessToService(app.AttachRESTClient(app.GetAllVectorDBsHandler)))
127106

128107
// POST to register the vectorDB (/v1/vector-dbs)
129-
apiRouter.POST(VectorDBListPath, app.RequireAuthRoute(app.AttachRESTClient(app.RegisterVectorDBHandler)))
130-
apiRouter.POST(UploadPath, app.RequireAuthRoute(app.AttachRESTClient(app.UploadHandler)))
131-
apiRouter.POST(QueryPath, app.RequireAuthRoute(app.AttachRESTClient(app.QueryHandler)))
108+
apiRouter.POST(app.getVectorDBListPath(), app.RequireAccessToService(app.AttachRESTClient(app.RegisterVectorDBHandler)))
109+
apiRouter.POST(app.getUploadPath(), app.RequireAccessToService(app.AttachRESTClient(app.UploadHandler)))
110+
apiRouter.POST(app.getQueryPath(), app.RequireAccessToService(app.AttachRESTClient(app.QueryHandler)))
132111

133112
// App Router
134113
appMux := http.NewServeMux()
135114

136-
//// Register /api/v1/config as a public endpoint
137-
//appMux.HandleFunc(ApiPathPrefix+"/config", func(w http.ResponseWriter, r *http.Request) {
138-
// app.HandleConfig(w, r, nil)
139-
//})
140-
//
141-
//// Register /api/v1/auth/callback as a public endpoint
142-
//appMux.HandleFunc(ApiPathPrefix+"/auth/callback", func(w http.ResponseWriter, r *http.Request) {
143-
// app.HandleOAuthCallback(w, r, nil)
144-
//})
145-
//
146-
//// Register /api/v1/auth/state as a public endpoint
147-
//appMux.HandleFunc(ApiPathPrefix+"/auth/state", func(w http.ResponseWriter, r *http.Request) {
148-
// app.HandleOAuthState(w, r, nil)
149-
//})
150-
151-
//All other /api/v1/* routes require auth
152-
appMux.Handle(ApiPathPrefix+"/", apiRouter)
115+
//All other API routes require auth
116+
appMux.Handle(app.config.APIPathPrefix+"/", apiRouter)
153117

154118
// Llama Stack proxy handler (unprotected)
155119
appMux.HandleFunc("/llama-stack/", app.HandleLlamaStackProxy)
@@ -164,7 +128,7 @@ func (app *App) Routes() http.Handler {
164128

165129
// Skip API routes
166130
if (r.URL.Path == "/" || r.URL.Path == "/index.html") ||
167-
(len(r.URL.Path) > 0 && r.URL.Path[0] == '/' && !isAPIRoute(r.URL.Path)) {
131+
(len(r.URL.Path) > 0 && r.URL.Path[0] == '/' && !app.isAPIRoute(r.URL.Path)) {
168132

169133
// Check if the requested file exists
170134
cleanPath := path.Clean(r.URL.Path)
@@ -200,7 +164,7 @@ func (app *App) Routes() http.Handler {
200164
combinedMux.HandleFunc(OpenAPIYAMLPath, app.openAPI.HandleOpenAPIYAMLWrapper)
201165
combinedMux.HandleFunc(SwaggerUIPath, app.openAPI.HandleSwaggerUIWrapper)
202166

203-
combinedMux.Handle("/", app.RecoverPanic(app.EnableTelemetry(app.EnableCORS(appMux))))
167+
combinedMux.Handle("/", app.RecoverPanic(app.EnableTelemetry(app.EnableCORS(app.InjectRequestIdentity(appMux)))))
204168

205169
return combinedMux
206170
}

frontend/packages/llama-stack-modular-ui/bff/internal/api/app_test.go

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
package api
22

33
import (
4+
"log/slog"
45
"testing"
56

7+
"github.com/opendatahub-io/llama-stack-modular-ui/internal/config"
68
"github.com/stretchr/testify/assert"
79
)
810

911
func TestIsAPIRoute(t *testing.T) {
12+
// Create a test app with default API path prefix
13+
cfg := config.EnvConfig{
14+
APIPathPrefix: "/api/v1",
15+
}
16+
app := &App{
17+
config: cfg,
18+
logger: slog.Default(),
19+
}
20+
1021
tests := []struct {
1122
name string
1223
path string
@@ -49,42 +60,42 @@ func TestIsAPIRoute(t *testing.T) {
4960
// API v1 routes - exact matches
5061
{
5162
name: "api v1 prefix exact match",
52-
path: ApiPathPrefix,
63+
path: "/api/v1",
5364
expected: true,
5465
},
5566
{
5667
name: "api v1 config",
57-
path: ApiPathPrefix + "/config",
68+
path: "/api/v1/config",
5869
expected: true,
5970
},
6071
{
6172
name: "api v1 models",
62-
path: ApiPathPrefix + "/models",
73+
path: "/api/v1/models",
6374
expected: true,
6475
},
6576
{
6677
name: "api v1 vector-dbs",
67-
path: ApiPathPrefix + "/vector-dbs",
78+
path: "/api/v1/vector-dbs",
6879
expected: true,
6980
},
7081
{
7182
name: "api v1 upload",
72-
path: ApiPathPrefix + "/upload",
83+
path: "/api/v1/upload",
7384
expected: true,
7485
},
7586
{
7687
name: "api v1 query",
77-
path: ApiPathPrefix + "/query",
88+
path: "/api/v1/query",
7889
expected: true,
7990
},
8091
{
8192
name: "api v1 auth callback",
82-
path: ApiPathPrefix + "/auth/callback",
93+
path: "/api/v1/auth/callback",
8394
expected: true,
8495
},
8596
{
8697
name: "api v1 auth state",
87-
path: ApiPathPrefix + "/auth/state",
98+
path: "/api/v1/auth/state",
8899
expected: true,
89100
},
90101

@@ -226,7 +237,7 @@ func TestIsAPIRoute(t *testing.T) {
226237

227238
for _, tt := range tests {
228239
t.Run(tt.name, func(t *testing.T) {
229-
result := isAPIRoute(tt.path)
240+
result := app.isAPIRoute(tt.path)
230241
assert.Equal(t, tt.expected, result,
231242
"Path: %s, Expected: %v, Got: %v", tt.path, tt.expected, result)
232243
})
@@ -235,6 +246,14 @@ func TestIsAPIRoute(t *testing.T) {
235246

236247
// TestIsAPIRouteEdgeCases tests additional edge cases and boundary conditions
237248
func TestIsAPIRouteEdgeCases(t *testing.T) {
249+
// Create a test app with default API path prefix
250+
cfg := config.EnvConfig{
251+
APIPathPrefix: "/api/v1",
252+
}
253+
app := &App{
254+
config: cfg,
255+
logger: slog.Default(),
256+
}
238257
tests := []struct {
239258
name string
240259
path string
@@ -304,7 +323,7 @@ func TestIsAPIRouteEdgeCases(t *testing.T) {
304323

305324
for _, tt := range tests {
306325
t.Run(tt.name, func(t *testing.T) {
307-
result := isAPIRoute(tt.path)
326+
result := app.isAPIRoute(tt.path)
308327
assert.Equal(t, tt.expected, result,
309328
"Path: %s, Expected: %v, Got: %v", tt.path, tt.expected, result)
310329
})
@@ -313,13 +332,22 @@ func TestIsAPIRouteEdgeCases(t *testing.T) {
313332

314333
// TestIsAPIRoutePerformance tests that the function handles various path lengths efficiently
315334
func TestIsAPIRoutePerformance(t *testing.T) {
335+
// Create a test app with default API path prefix
336+
cfg := config.EnvConfig{
337+
APIPathPrefix: "/api/v1",
338+
}
339+
app := &App{
340+
config: cfg,
341+
logger: slog.Default(),
342+
}
343+
316344
// Test with very long paths to ensure no performance issues
317345
longPath := "/api/v1/" + string(make([]byte, 1000))
318-
result := isAPIRoute(longPath)
346+
result := app.isAPIRoute(longPath)
319347
assert.True(t, result, "Long API path should still be recognized as API route")
320348

321349
// Test with very long non-API paths
322350
longNonAPIPath := "/dashboard/" + string(make([]byte, 1000))
323-
result = isAPIRoute(longNonAPIPath)
351+
result = app.isAPIRoute(longNonAPIPath)
324352
assert.False(t, result, "Long non-API path should not be recognized as API route")
325353
}

frontend/packages/llama-stack-modular-ui/bff/internal/api/config_handler.go

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)