3 changes: 2 additions & 1 deletion litellm/integrations/gitlab/__init__.py
@@ -8,7 +8,7 @@
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import PromptSpec, PromptLiteLLMParams
-from .gitlab_prompt_manager import GitLabPromptManager
+from .gitlab_prompt_manager import GitLabPromptManager, GitLabPromptCache

# Global instances
global_gitlab_config: Optional[dict] = None
@@ -90,6 +90,7 @@ def _gitlab_prompt_initializer(
# Export public API
__all__ = [
"GitLabPromptManager",
"GitLabPromptCache",
"set_global_gitlab_config",
"global_gitlab_config",
]
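
With `GitLabPromptCache` re-exported from the package root, callers can build a cache straight from the integration module. A minimal sketch, assuming a reachable GitLab project and a valid token (every config value below is a placeholder, not a real project or credential):

```python
# Sketch only: project, token, and prompts_path are hypothetical placeholders.
from litellm.integrations.gitlab import GitLabPromptCache

gitlab_config = {
    "project": "my-group/my-repo",    # placeholder project path
    "access_token": "glpat_***",      # redacted placeholder token
    "prompts_path": "prompts/chat",   # optional; empty string means repo root
    "branch": "main",
}

cache = GitLabPromptCache(gitlab_config)
cache.load_all()          # fetch + parse every .prompt file under prompts_path
print(cache.list_ids())   # e.g. ["greet/hi", ...]
```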
137 changes: 136 additions & 1 deletion litellm/integrations/gitlab/gitlab_prompt_manager.py
@@ -248,7 +248,7 @@ class GitLabPromptManager(CustomPromptManagement):
"access_token": "glpat_***",
"tag": "v1.2.3", # optional; takes precedence
"branch": "main", # default fallback
"prompts_path": "prompts/chat" # <--- NEW
"prompts_path": "prompts/chat"
}
"""

@@ -486,3 +486,138 @@ def get_chat_completion_prompt(
prompt_label,
prompt_version,
)


class GitLabPromptCache:
"""
Cache all .prompt files from a GitLab repo into memory.

- Keys are the *repo file paths* (e.g. "prompts/chat/greet/hi.prompt")
mapped to JSON-like dicts containing content + metadata.
- Also exposes a by-ID view (ID == path relative to prompts_path without ".prompt",
e.g. "greet/hi").

Usage:

cfg = {
"project": "group/subgroup/repo",
"access_token": "glpat_***",
"prompts_path": "prompts/chat", # optional, can be empty for repo root
# "branch": "main", # default is "main"
# "tag": "v1.2.3", # takes precedence over branch
# "base_url": "https://gitlab.com/api/v4" # default
}

cache = GitLabPromptCache(cfg)
cache.load_all() # fetch + parse all .prompt files

print(cache.list_files()) # repo file paths
print(cache.list_ids()) # template IDs relative to prompts_path

prompt_json = cache.get_by_file("prompts/chat/greet/hi.prompt")
prompt_json2 = cache.get_by_id("greet/hi")

# If GitLab content changes and you want to refresh:
cache.reload() # re-scan and refresh all
"""

def __init__(
self,
gitlab_config: Dict[str, Any],
*,
ref: Optional[str] = None,
gitlab_client: Optional[GitLabClient] = None,
) -> None:
# Build a PromptManager (which internally builds TemplateManager + Client)
self.prompt_manager = GitLabPromptManager(
gitlab_config=gitlab_config,
prompt_id=None,
ref=ref,
gitlab_client=gitlab_client,
)
self.template_manager: GitLabTemplateManager = self.prompt_manager.prompt_manager

# In-memory stores
self._by_file: Dict[str, Dict[str, Any]] = {}
self._by_id: Dict[str, Dict[str, Any]] = {}

# -------------------------
# Public API
# -------------------------

def load_all(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
"""
Scan GitLab for all .prompt files under prompts_path, load and parse each,
and return the mapping of repo file path -> JSON-like dict.
"""
ids = self.template_manager.list_templates(recursive=recursive) # IDs relative to prompts_path
for pid in ids:
# Ensure template is loaded into TemplateManager
if pid not in self.template_manager.prompts:
self.template_manager._load_prompt_from_gitlab(pid)

tmpl = self.template_manager.get_template(pid)
if tmpl is None:
# If something raced/failed, try once more
self.template_manager._load_prompt_from_gitlab(pid)
tmpl = self.template_manager.get_template(pid)
if tmpl is None:
continue

file_path = self.template_manager._id_to_repo_path(pid) # "prompts/chat/..../file.prompt"
entry = self._template_to_json(pid, tmpl)

self._by_file[file_path] = entry
self._by_id[pid] = entry

return self._by_id

def reload(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
"""Clear the cache and re-load from GitLab."""
self._by_file.clear()
self._by_id.clear()
return self.load_all(recursive=recursive)

def list_files(self) -> List[str]:
"""Return the repo file paths currently cached."""
return list(self._by_file.keys())

def list_ids(self) -> List[str]:
"""Return the template IDs (relative to prompts_path, without extension) currently cached."""
return list(self._by_id.keys())

def get_by_file(self, file_path: str) -> Optional[Dict[str, Any]]:
"""Get a cached prompt JSON by repo file path."""
return self._by_file.get(file_path)

def get_by_id(self, prompt_id: str) -> Optional[Dict[str, Any]]:
"""Get a cached prompt JSON by prompt ID (relative to prompts_path)."""
return self._by_id.get(prompt_id)

# -------------------------
# Internals
# -------------------------

def _template_to_json(self, prompt_id: str, tmpl: GitLabPromptTemplate) -> Dict[str, Any]:
"""
Normalize a GitLabPromptTemplate into a JSON-like dict that is easy to serialize.
"""
# Shallow-copy metadata so callers can't mutate the cached template
md = dict(tmpl.metadata or {})

# Pull standard fields (sometimes duplicated in metadata)
model = tmpl.model
temperature = tmpl.temperature
max_tokens = tmpl.max_tokens
optional_params = dict(tmpl.optional_params or {})

return {
"id": prompt_id, # e.g. "greet/hi"
"path": self.template_manager._id_to_repo_path(prompt_id), # e.g. "prompts/chat/greet/hi.prompt"
"content": tmpl.content, # rendered content (without frontmatter)
"metadata": md, # parsed frontmatter
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
"optional_params": optional_params,
}
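
Because `_template_to_json` normalizes each template into a plain dict, the whole cache can be serialized in one step. A short sketch, assuming `load_all()` succeeds against the placeholder `gitlab_config` shown earlier:

```python
import json

from litellm.integrations.gitlab import GitLabPromptCache

cache = GitLabPromptCache(gitlab_config)  # placeholder config from the sketch above
cache.load_all()

# One JSON document keyed by template ID, e.g. {"greet/hi": {"content": ..., ...}}.
snapshot = {pid: cache.get_by_id(pid) for pid in cache.list_ids()}
print(json.dumps(snapshot, indent=2, default=str))  # default=str guards odd metadata values
```

Since `reload()` clears both indexes before re-scanning, calling it periodically is the simplest way to pick up upstream edits without restarting the process.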
33 changes: 26 additions & 7 deletions litellm/proxy/prompts/prompt_endpoints.py
@@ -77,7 +77,7 @@ async def list_prompts(
```
"""
from litellm.proxy._types import LitellmUserRoles
-    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
+    from litellm.proxy.prompts.prompt_registry import PROMPT_HUB

# check key metadata for prompts
key_metadata = user_api_key_dict.metadata
@@ -86,9 +86,9 @@
if prompts is not None:
return ListPromptsResponse(
prompts=[
-                IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS[prompt]
+                PROMPT_HUB.IN_MEMORY_PROMPTS[prompt]
for prompt in prompts
-                if prompt in IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS
+                if prompt in PROMPT_HUB.IN_MEMORY_PROMPTS
]
)
# check if user is proxy admin - show all prompts
@@ -97,7 +97,7 @@
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
):
return ListPromptsResponse(
-            prompts=list(IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS.values())
+            prompts=list(PROMPT_HUB.IN_MEMORY_PROMPTS.values())
)
else:
return ListPromptsResponse(prompts=[])
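
The rename from `IN_MEMORY_PROMPT_REGISTRY` to `PROMPT_HUB` looks mechanical: every call site keeps the same attribute access. The key-scoped filter above boils down to this pattern (a sketch; the `key_metadata` dict is a hypothetical stand-in for `user_api_key_dict.metadata`):

```python
from litellm.proxy.prompts.prompt_registry import PROMPT_HUB

key_metadata = {"prompts": ["greet/hi"]}  # hypothetical key metadata
allowed = key_metadata.get("prompts")

# Return only prompts that are both allowed for this key and actually registered.
visible = [
    PROMPT_HUB.IN_MEMORY_PROMPTS[p]
    for p in (allowed or [])
    if p in PROMPT_HUB.IN_MEMORY_PROMPTS
]
```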
@@ -148,7 +148,7 @@ async def get_prompt_info(
}
```
"""
-    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
+    from litellm.proxy.prompts.prompt_registry import PROMPT_HUB

## CHECK IF USER HAS ACCESS TO PROMPT
prompts: Optional[List[str]] = None
@@ -169,14 +169,19 @@
detail=f"You are not authorized to access this prompt. Your role - {user_api_key_dict.user_role}, Your key's prompts - {prompts}",
)

-    prompt_spec = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(prompt_id)
+    prompt_spec = PROMPT_HUB.get_prompt_by_id(prompt_id)
verbose_proxy_logger.debug(f"found prompt with id {prompt_id}-->{prompt_spec}")
if prompt_spec is None:
raise HTTPException(status_code=400, detail=f"Prompt {prompt_id} not found")

# Get prompt content from the callback
prompt_template: Optional[PromptTemplateBase] = None
try:
-        prompt_callback = IN_MEMORY_PROMPT_REGISTRY.get_prompt_callback_by_id(prompt_id)
+        prompt_callback = PROMPT_HUB.get_prompt_callback_by_id(prompt_id)
+        verbose_proxy_logger.debug(
+            f"Found the prompt callback for prompt id {prompt_id} --> {prompt_callback}"
+        )
+
if prompt_callback is not None:
# Extract content based on integration type
integration_name = prompt_callback.integration_name
@@ -196,6 +201,20 @@
content=template[template_id]["content"],
metadata=template[template_id]["metadata"],
)
if integration_name == "gitlab":
from litellm.integrations.gitlab import (
GitLabPromptManager,
)
if isinstance(prompt_callback, GitLabPromptManager):
template = prompt_callback.prompt_manager.get_all_prompts_as_json()
if template is not None and len(template) == 1:
template_id = list(template.keys())[0]
prompt_template = PromptTemplateBase(
litellm_prompt_id=template_id, # id sent to prompt management tool
content=template[template_id]["content"],
metadata=template[template_id]["metadata"],
)


except Exception:
# If content extraction fails, continue without content
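
The new gitlab branch mirrors the existing dotprompt extraction. Pulled out as a standalone helper, it would read roughly as below: a sketch that assumes `PromptTemplateBase` is in scope as it is in this module, and that keeps the endpoint's single-template assumption (`len(template) == 1`):

```python
from typing import Any, Optional

from litellm.integrations.gitlab import GitLabPromptManager

def extract_gitlab_template(prompt_callback: Any) -> Optional["PromptTemplateBase"]:
    # Hypothetical helper mirroring the gitlab branch of get_prompt_info.
    if not isinstance(prompt_callback, GitLabPromptManager):
        return None
    template = prompt_callback.prompt_manager.get_all_prompts_as_json()
    if not template or len(template) != 1:
        return None  # same single-template assumption as the endpoint code
    template_id = next(iter(template))
    return PromptTemplateBase(
        litellm_prompt_id=template_id,  # id sent to the prompt management tool
        content=template[template_id]["content"],
        metadata=template[template_id]["metadata"],
    )
```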