18
18
19
19
use std:: fs:: File ;
20
20
use std:: io:: BufReader ;
21
+ use std:: sync:: Arc ;
21
22
22
23
use actix_cors:: Cors ;
23
- use actix_web:: { web, App , HttpServer } ;
24
+ use actix_web:: {
25
+ web:: { self , resource} ,
26
+ App , HttpServer ,
27
+ } ;
24
28
use actix_web_prometheus:: PrometheusMetrics ;
25
29
use actix_web_static_files:: ResourceFiles ;
30
+ use log:: info;
31
+ use openid:: Discovered ;
26
32
use rustls:: { Certificate , PrivateKey , ServerConfig } ;
27
33
use rustls_pemfile:: { certs, pkcs8_private_keys} ;
28
34
@@ -37,33 +43,47 @@ mod ingest;
37
43
mod llm;
38
44
mod logstream;
39
45
mod middleware;
46
+ mod oidc;
40
47
mod query;
41
48
mod rbac;
49
+ mod role;
42
50
43
51
include ! ( concat!( env!( "OUT_DIR" ) , "/generated.rs" ) ) ;
44
52
45
53
const MAX_EVENT_PAYLOAD_SIZE : usize = 10485760 ;
46
54
const API_BASE_PATH : & str = "/api" ;
47
55
const API_VERSION : & str = "v1" ;
48
56
49
- #[ macro_export]
50
- macro_rules! create_app {
51
- ( $prometheus: expr) => {
57
+ pub async fn run_http (
58
+ prometheus : PrometheusMetrics ,
59
+ oidc_client : Option < crate :: oidc:: OpenidConfig > ,
60
+ ) -> anyhow:: Result < ( ) > {
61
+ let oidc_client = match oidc_client {
62
+ Some ( config) => {
63
+ let client = config
64
+ . connect ( & format ! ( "{API_BASE_PATH}/{API_VERSION}/o/code" ) )
65
+ . await ?;
66
+ Some ( Arc :: new ( client) )
67
+ }
68
+ None => None ,
69
+ } ;
70
+
71
+ let create_app = move || {
52
72
App :: new ( )
53
- . wrap( $ prometheus. clone( ) )
54
- . configure( |cfg| configure_routes( cfg) )
73
+ . wrap ( prometheus. clone ( ) )
74
+ . configure ( |cfg| configure_routes ( cfg, oidc_client . clone ( ) ) )
55
75
. wrap ( actix_web:: middleware:: Logger :: default ( ) )
56
76
. wrap ( actix_web:: middleware:: Compress :: default ( ) )
57
77
. wrap (
58
78
Cors :: default ( )
59
79
. allow_any_header ( )
60
80
. allow_any_method ( )
61
- . allow_any_origin( ) ,
81
+ . allow_any_origin ( )
82
+ . expose_any_header ( )
83
+ . supports_credentials ( ) ,
62
84
)
63
85
} ;
64
- }
65
86
66
- pub async fn run_http ( prometheus : PrometheusMetrics ) -> anyhow:: Result < ( ) > {
67
87
let ssl_acceptor = match (
68
88
& CONFIG . parseable . tls_cert_path ,
69
89
& CONFIG . parseable . tls_key_path ,
@@ -99,7 +119,7 @@ pub async fn run_http(prometheus: PrometheusMetrics) -> anyhow::Result<()> {
99
119
} ;
100
120
101
121
// concurrent workers equal to number of cores on the cpu
102
- let http_server = HttpServer :: new ( move || create_app ! ( prometheus ) ) . workers ( num_cpus:: get ( ) ) ;
122
+ let http_server = HttpServer :: new ( create_app) . workers ( num_cpus:: get ( ) ) ;
103
123
if let Some ( config) = ssl_acceptor {
104
124
http_server
105
125
. bind_rustls ( & CONFIG . parseable . address , config) ?
@@ -112,7 +132,10 @@ pub async fn run_http(prometheus: PrometheusMetrics) -> anyhow::Result<()> {
112
132
Ok ( ( ) )
113
133
}
114
134
115
- pub fn configure_routes ( cfg : & mut web:: ServiceConfig ) {
135
+ pub fn configure_routes (
136
+ cfg : & mut web:: ServiceConfig ,
137
+ oidc_client : Option < Arc < openid:: Client < Discovered , crate :: oidc:: Claims > > > ,
138
+ ) {
116
139
let generated = generate ( ) ;
117
140
118
141
//log stream API
@@ -211,13 +234,13 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
211
234
. route (
212
235
web:: put ( )
213
236
. to ( rbac:: put_role)
214
- . authorize ( Action :: PutRoles )
237
+ . authorize ( Action :: PutUserRoles )
215
238
. wrap ( DisAllowRootUser ) ,
216
239
)
217
240
. route (
218
241
web:: get ( )
219
242
. to ( rbac:: get_role)
220
- . authorize_for_user ( Action :: GetRole ) ,
243
+ . authorize_for_user ( Action :: GetUserRoles ) ,
221
244
) ,
222
245
)
223
246
. service (
@@ -238,6 +261,24 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
238
261
. authorize ( Action :: QueryLLM ) ,
239
262
) ,
240
263
) ;
264
+ let role_api = web:: scope ( "/role" )
265
+ . service ( resource ( "" ) . route ( web:: get ( ) . to ( role:: list) . authorize ( Action :: ListRole ) ) )
266
+ . service (
267
+ resource ( "/{name}" )
268
+ . route ( web:: put ( ) . to ( role:: put) . authorize ( Action :: PutRole ) )
269
+ . route ( web:: delete ( ) . to ( role:: delete) . authorize ( Action :: DeleteRole ) )
270
+ . route ( web:: get ( ) . to ( role:: get) . authorize ( Action :: GetRole ) ) ,
271
+ ) ;
272
+
273
+ let mut oauth_api = web:: scope ( "/o" )
274
+ . service ( resource ( "/login" ) . route ( web:: get ( ) . to ( oidc:: login) ) )
275
+ . service ( resource ( "/logout" ) . route ( web:: get ( ) . to ( oidc:: logout) ) )
276
+ . service ( resource ( "/code" ) . route ( web:: get ( ) . to ( oidc:: reply_login) ) ) ;
277
+
278
+ if let Some ( client) = oidc_client {
279
+ info ! ( "Registered oidc client" ) ;
280
+ oauth_api = oauth_api. app_data ( web:: Data :: from ( client) )
281
+ }
241
282
242
283
// Deny request if username is same as the env variable P_USERNAME.
243
284
cfg. service (
@@ -280,7 +321,9 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
280
321
) ,
281
322
)
282
323
. service ( user_api)
283
- . service ( llm_query_api) ,
324
+ . service ( llm_query_api)
325
+ . service ( oauth_api)
326
+ . service ( role_api) ,
284
327
)
285
328
// GET "/" ==> Serve the static frontend directory
286
329
. service ( ResourceFiles :: new ( "/" , generated) . resolve_not_found_to_root ( ) ) ;
0 commit comments