|
1 | 1 | import importlib |
| 2 | +import importlib.util |
2 | 3 | import json |
3 | 4 | import os |
4 | 5 | from datetime import datetime, timezone |
|
11 | 12 |
|
12 | 13 | from typing import List, Literal, Optional |
13 | 14 | from types import ModuleType |
| 15 | +from guardrails.hub.registry import get_registry_path |
14 | 16 | from packaging.utils import canonicalize_name # PEP 503 |
15 | 17 |
|
16 | 18 | from guardrails.logger import logger as guardrails_logger |
|
23 | 25 | from guardrails_hub_types import Manifest |
24 | 26 | from guardrails.cli.server.hub_client import get_validator_manifest |
25 | 27 | from guardrails.settings import settings |
| 28 | +from guardrails.types.validator_registry import ValidatorRegistry |
26 | 29 |
|
27 | 30 |
|
28 | 31 | json_format: Literal["json"] = "json" |
@@ -72,11 +75,6 @@ def detect_installer() -> str: |
72 | 75 | return "uv" |
73 | 76 | return "pip" |
74 | 77 |
|
75 | | - @staticmethod |
76 | | - def get_registry_path() -> Path: |
77 | | - """Return the project-level registry path.""" |
78 | | - return Path(os.getcwd()) / ".guardrails" / "hub_registry.json" |
79 | | - |
80 | 78 | @staticmethod |
81 | 79 | def get_manifest_and_site_packages(module_name: str) -> tuple[Manifest, str]: |
82 | 80 | module_manifest = get_validator_manifest(module_name) |
@@ -131,10 +129,28 @@ def get_validator_from_manifest(manifest: Manifest) -> ModuleType: |
131 | 129 | # Reload or import the module |
132 | 130 | return ValidatorPackageService.reload_module(import_line) |
133 | 131 |
|
| 132 | + @staticmethod |
| 133 | + def rewrite_stub_file(registry: ValidatorRegistry): |
| 134 | + stub_file = ( |
| 135 | + Path(ValidatorPackageService.get_site_packages_location()) |
| 136 | + / "guardrails" |
| 137 | + / "hub" |
| 138 | + / "__init__.pyi" |
| 139 | + ) |
| 140 | + |
| 141 | + import_statements = [] |
| 142 | + for v in registry.validators.values(): |
| 143 | + if v.exports and v.import_path and importlib.util.find_spec(v.import_path): |
| 144 | + import_statements.extend( |
| 145 | + [f"from {v.import_path} import {e} as {e}" for e in v.exports] |
| 146 | + ) |
| 147 | + |
| 148 | + stub_file.write_text("\n".join(import_statements)) |
| 149 | + |
134 | 150 | @staticmethod |
135 | 151 | def register_validator(manifest: Manifest): |
136 | 152 | """Register a validator in the project-level JSON registry.""" |
137 | | - registry_file = ValidatorPackageService.get_registry_path() |
| 153 | + registry_file = get_registry_path() |
138 | 154 | registry_file.parent.mkdir(parents=True, exist_ok=True) |
139 | 155 |
|
140 | 156 | registry = {"version": 1, "validators": {}} |
@@ -168,6 +184,35 @@ def register_validator(manifest: Manifest): |
168 | 184 |
|
169 | 185 | registry_file.write_text(json.dumps(registry, indent=2)) |
170 | 186 |
|
| 187 | + ValidatorPackageService.rewrite_stub_file( |
| 188 | + ValidatorRegistry.model_validate(registry) |
| 189 | + ) |
| 190 | + |
| 191 | + @staticmethod |
| 192 | + def unregister_validator(validator_id: str): |
| 193 | + """Remove a validator from the project-level JSON registry.""" |
| 194 | + registry_file = get_registry_path() |
| 195 | + if not registry_file.exists(): |
| 196 | + return |
| 197 | + |
| 198 | + try: |
| 199 | + registry = json.loads(registry_file.read_text()) |
| 200 | + except (json.JSONDecodeError, OSError): |
| 201 | + guardrails_logger.debug( |
| 202 | + "Registry at %s is unreadable; skipping unregister", |
| 203 | + registry_file, |
| 204 | + ) |
| 205 | + return |
| 206 | + |
| 207 | + validators = registry.get("validators", {}) |
| 208 | + if validator_id in validators: |
| 209 | + del validators[validator_id] |
| 210 | + registry["validators"] = validators |
| 211 | + registry_file.write_text(json.dumps(registry, indent=2)) |
| 212 | + ValidatorPackageService.rewrite_stub_file( |
| 213 | + ValidatorRegistry.model_validate(registry) |
| 214 | + ) |
| 215 | + |
171 | 216 | @staticmethod |
172 | 217 | def add_to_hub_inits(manifest: Manifest, site_packages: str): |
173 | 218 | validator_id = manifest.id |
|
0 commit comments