1
1
import os
2
2
import time
3
- import types
4
3
from asyncio import get_event_loop_policy
5
4
from functools import partial
6
5
from typing import TYPE_CHECKING , Optional
19
18
from traitlets import Integer , List , Type , Unicode
20
19
from traitlets .config import Config
21
20
21
+ from .secrets .secrets_rest_api import SecretsRestAPI
22
+ from .secrets .secrets_manager import EnvSecretsManager
22
23
from .completions .handlers import DefaultInlineCompletionHandler
23
24
from .config_manager import ConfigManager
24
25
from .handlers import (
25
- ApiKeysHandler ,
26
26
GlobalConfigHandler ,
27
27
InterruptStreamingHandler ,
28
28
)
58
58
class AiExtension (ExtensionApp ):
59
59
name = "jupyter_ai"
60
60
handlers = [ # type:ignore[assignment]
61
- (r"api/ai/api_keys/(?P<api_key_name>\w+)/?" , ApiKeysHandler ),
62
61
(r"api/ai/config/?" , GlobalConfigHandler ),
63
62
(r"api/ai/chats/stop_streaming/?" , InterruptStreamingHandler ),
64
63
(r"api/ai/completion/inline/?" , DefaultInlineCompletionHandler ),
65
64
(r"api/ai/models/chat/?" , ChatModelEndpoint ),
65
+ (r"api/ai/secrets/?" , SecretsRestAPI ),
66
66
(
67
67
r"api/ai/static/jupyternaut.svg()/?" ,
68
68
StaticFileHandler ,
@@ -295,20 +295,12 @@ def on_change(
295
295
def initialize_settings (self ):
296
296
start = time .time ()
297
297
298
- # Read from allowlist and blocklist
299
- restrictions = {
300
- "allowed_providers" : self .allowed_providers ,
301
- "blocked_providers" : self .blocked_providers ,
302
- }
303
- self .settings ["allowed_models" ] = self .allowed_models
304
- self .settings ["blocked_models" ] = self .blocked_models
298
+ # Log traitlets configuration
305
299
self .log .info (f"Configured provider allowlist: { self .allowed_providers } " )
306
300
self .log .info (f"Configured provider blocklist: { self .blocked_providers } " )
307
301
self .log .info (f"Configured model allowlist: { self .allowed_models } " )
308
302
self .log .info (f"Configured model blocklist: { self .blocked_models } " )
309
- self .settings ["model_parameters" ] = self .model_parameters
310
303
self .log .info (f"Configured model parameters: { self .model_parameters } " )
311
-
312
304
defaults = {
313
305
"model_provider_id" : self .default_language_model ,
314
306
"embeddings_provider_id" : self .default_embeddings_model ,
@@ -319,8 +311,8 @@ def initialize_settings(self):
319
311
"completions_fields" : self .model_parameters ,
320
312
}
321
313
314
+ # Initialize ConfigManager
322
315
self .settings ["jai_config_manager" ] = ConfigManager (
323
- # traitlets configuration, not JAI configuration.
324
316
config = self .config ,
325
317
log = self .log ,
326
318
allowed_providers = self .allowed_providers ,
@@ -330,16 +322,21 @@ def initialize_settings(self):
330
322
defaults = defaults ,
331
323
)
332
324
333
- self .log .info (f"Registered { self .name } server extension" )
325
+ # Initialize SecretsManager
326
+ self .settings ["jai_secrets_manager" ] = EnvSecretsManager (parent = self )
334
327
328
+ # Bind event loop to settings dictionary
335
329
self .settings ["jai_event_loop" ] = self .event_loop
336
330
337
- # Create empty dictionary for events communicating that
338
- # message generation/streaming got interrupted.
331
+ # Bind dictionary of interrupts to settings dictionary.
332
+ # Each key is a message ID, each value is an asyncio.Event.
333
+ # When a message's interrupt event is set, the response is halted.
339
334
self .settings ["jai_message_interrupted" ] = {}
340
335
341
- latency_ms = round ((time .time () - start ) * 1000 )
342
- self .log .info (f"Initialized Jupyter AI server extension in { latency_ms } ms." )
336
+ # Log server extension startup time
337
+ self .log .info (f"Registered { self .name } server extension" )
338
+ startup_time = round ((time .time () - start ) * 1000 )
339
+ self .log .info (f"Initialized Jupyter AI server extension in { startup_time } ms." )
343
340
344
341
async def stop_extension (self ):
345
342
"""
@@ -359,7 +356,10 @@ async def _stop_extension(self):
359
356
Private method that defines the cleanup code to run when the server is
360
357
stopping.
361
358
"""
362
- # TODO: explore if cleanup is necessary
359
+ secrets_manager = self .settings .get ("jai_secrets_manager" , None )
360
+
361
+ if secrets_manager :
362
+ secrets_manager .stop ()
363
363
364
364
def _init_persona_manager (
365
365
self , room_id : str , ychat : YChat
@@ -428,7 +428,6 @@ def _link_jupyter_server_extension(self, server_app: ServerApp):
428
428
".git" , # Git version control directory
429
429
".venv" , # Python virtual environment directory
430
430
"venv" , # Python virtual environment directory
431
- ".env" , # Environment variable files
432
431
"node_modules" , # Node.js dependencies directory
433
432
".pytest_cache" , # PyTest cache directory
434
433
".mypy_cache" , # MyPy type checker cache directory
0 commit comments