Skip to content

Commit 0a0e221

Browse files
committed
Credentials -> RC; add RC to settings
1 parent 759cc1f commit 0a0e221

File tree

17 files changed

+198
-146
lines changed

17 files changed

+198
-146
lines changed

guardrails/classes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from guardrails.classes.credentials import Credentials
2+
from guardrails.classes.rc import RC
23
from guardrails.classes.input_type import InputType
34
from guardrails.classes.output_type import OT
45
from guardrails.classes.validation.validation_result import (
@@ -11,6 +12,7 @@
1112

1213
__all__ = [
1314
"Credentials",
15+
"RC",
1416
"ErrorSpan",
1517
"InputType",
1618
"OT",

guardrails/classes/credentials.py

Lines changed: 21 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
import logging
2-
import os
32
from dataclasses import dataclass
4-
from os.path import expanduser
53
from typing import Optional
4+
from typing_extensions import deprecated
65

7-
from guardrails.classes.generic.serializeable import Serializeable
6+
from guardrails.classes.generic.serializeable import SerializeableJSONEncoder
7+
from guardrails.classes.rc import RC
88

99
BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"])
1010

1111

12+
@deprecated(
13+
(
14+
"The `Credentials` class is deprecated and will be removed in version 0.6.x."
15+
" Use the `RC` class instead."
16+
),
17+
category=DeprecationWarning,
18+
)
1219
@dataclass
13-
class Credentials(Serializeable):
14-
id: Optional[str] = None
15-
token: Optional[str] = None
20+
class Credentials(RC):
1621
no_metrics: Optional[bool] = False
17-
enable_metrics: Optional[bool] = True
18-
use_remote_inferencing: Optional[bool] = True
1922

2023
@staticmethod
2124
def _to_bool(value: str) -> Optional[bool]:
@@ -27,51 +30,16 @@ def _to_bool(value: str) -> Optional[bool]:
2730

2831
@staticmethod
2932
def has_rc_file() -> bool:
30-
home = expanduser("~")
31-
guardrails_rc = os.path.join(home, ".guardrailsrc")
32-
return os.path.exists(guardrails_rc)
33+
return RC.exists()
3334

3435
@staticmethod
3536
def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials":
36-
try:
37-
if not logger:
38-
logger = logging.getLogger()
39-
home = expanduser("~")
40-
guardrails_rc = os.path.join(home, ".guardrailsrc")
41-
with open(guardrails_rc, encoding="utf-8") as rc_file:
42-
lines = rc_file.readlines()
43-
filtered_lines = list(filter(lambda l: l.strip(), lines))
44-
creds = {}
45-
for line in filtered_lines:
46-
line_content = line.split("=", 1)
47-
if len(line_content) != 2:
48-
logger.warning(
49-
"""
50-
Invalid line found in .guardrailsrc file!
51-
All lines in this file should follow the format: key=value
52-
Ignoring line contents...
53-
"""
54-
)
55-
logger.debug(f".guardrailsrc file location: {guardrails_rc}")
56-
else:
57-
key, value = line_content
58-
key = key.strip()
59-
value = value.strip()
60-
if key in BOOL_CONFIGS:
61-
value = Credentials._to_bool(value)
62-
63-
creds[key] = value
64-
65-
rc_file.close()
66-
67-
# backfill no_metrics, handle defaults
68-
# remove in 0.5.0
69-
no_metrics_val = creds.pop("no_metrics", None)
70-
if no_metrics_val is not None and creds.get("enable_metrics") is None:
71-
creds["enable_metrics"] = not no_metrics_val
72-
73-
creds_dict = Credentials.from_dict(creds)
74-
return creds_dict
75-
76-
except FileNotFoundError:
77-
return Credentials.from_dict({}) # type: ignore
37+
rc = RC.load(logger)
38+
return Credentials(
39+
id=rc.id,
40+
token=rc.token,
41+
enable_metrics=rc.enable_metrics,
42+
use_remote_inferencing=rc.use_remote_inferencing,
43+
no_metrics=(not rc.enable_metrics),
44+
encoder=SerializeableJSONEncoder(),
45+
)

guardrails/classes/rc.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import logging
2+
import os
3+
from dataclasses import dataclass
4+
from os.path import expanduser
5+
from typing import Optional
6+
7+
from guardrails.classes.generic.serializeable import Serializeable
8+
from guardrails.utils.casting_utils import to_bool
9+
10+
BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"])
11+
12+
13+
@dataclass
14+
class RC(Serializeable):
15+
id: Optional[str] = None
16+
token: Optional[str] = None
17+
enable_metrics: Optional[bool] = True
18+
use_remote_inferencing: Optional[bool] = True
19+
20+
@staticmethod
21+
def exists() -> bool:
22+
home = expanduser("~")
23+
guardrails_rc = os.path.join(home, ".guardrailsrc")
24+
return os.path.exists(guardrails_rc)
25+
26+
@classmethod
27+
def load(cls, logger: Optional[logging.Logger] = None) -> "RC":
28+
try:
29+
if not logger:
30+
logger = logging.getLogger()
31+
home = expanduser("~")
32+
guardrails_rc = os.path.join(home, ".guardrailsrc")
33+
with open(guardrails_rc, encoding="utf-8") as rc_file:
34+
lines = rc_file.readlines()
35+
filtered_lines = list(filter(lambda l: l.strip(), lines))
36+
config = {}
37+
for line in filtered_lines:
38+
line_content = line.split("=", 1)
39+
if len(line_content) != 2:
40+
logger.warning(
41+
"""
42+
Invalid line found in .guardrailsrc file!
43+
All lines in this file should follow the format: key=value
44+
Ignoring line contents...
45+
"""
46+
)
47+
logger.debug(f".guardrailsrc file location: {guardrails_rc}")
48+
else:
49+
key, value = line_content
50+
key = key.strip()
51+
value = value.strip()
52+
if key in BOOL_CONFIGS:
53+
value = to_bool(value)
54+
55+
config[key] = value
56+
57+
rc_file.close()
58+
59+
# backfill no_metrics, handle defaults
60+
# We missed this comment in the 0.5.0 release
61+
# Making it a TODO for 0.6.0
62+
# TODO: remove in 0.6.0
63+
no_metrics_val = config.pop("no_metrics", None)
64+
if no_metrics_val is not None and config.get("enable_metrics") is None:
65+
config["enable_metrics"] = not no_metrics_val
66+
del config["no_metrics"]
67+
68+
rc = cls.from_dict(config)
69+
return rc
70+
71+
except FileNotFoundError:
72+
return cls.from_dict({}) # type: ignore

guardrails/cli/configure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import typer
88

9-
from guardrails.classes.credentials import Credentials
9+
from guardrails.settings import settings
1010
from guardrails.cli.guardrails import guardrails
1111
from guardrails.cli.logger import LEVELS, logger
1212
from guardrails.cli.hub.console import console
@@ -46,7 +46,7 @@ def save_configuration_file(
4646

4747
def _get_default_token() -> str:
4848
"""Get the default token from the configuration file."""
49-
file_token = Credentials.from_rc_file(logger).token
49+
file_token = settings.rc.token
5050
if file_token is None:
5151
return ""
5252
return file_token

guardrails/cli/server/auth.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

guardrails/cli/server/hub_client.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from jwt import ExpiredSignatureError, DecodeError
1010

1111

12-
from guardrails.classes.credentials import Credentials
12+
from guardrails.settings import settings
13+
from guardrails.classes.rc import RC
1314
from guardrails.cli.logger import logger
1415
from guardrails.version import GUARDRAILS_VERSION
1516

@@ -86,8 +87,8 @@ def fetch_module_manifest(
8687
return fetch(manifest_url, token, anonymousUserId)
8788

8889

89-
def get_jwt_token(creds: Credentials) -> Optional[str]:
90-
token = creds.token
90+
def get_jwt_token(rc: RC) -> Optional[str]:
91+
token = rc.token
9192

9293
# check for jwt expiration
9394
if token:
@@ -101,23 +102,21 @@ def get_jwt_token(creds: Credentials) -> Optional[str]:
101102

102103

103104
def fetch_module(module_name: str) -> Optional[Manifest]:
104-
creds = Credentials.from_rc_file(logger)
105-
token = get_jwt_token(creds)
105+
token = get_jwt_token(settings.rc)
106106

107-
module_manifest_json = fetch_module_manifest(module_name, token, creds.id)
107+
module_manifest_json = fetch_module_manifest(module_name, token, settings.rc.id)
108108
return Manifest.from_dict(module_manifest_json)
109109

110110

111111
def fetch_template(template_address: str) -> Dict[str, Any]:
112-
creds = Credentials.from_rc_file(logger)
113-
token = get_jwt_token(creds)
112+
token = get_jwt_token(settings.rc)
114113

115114
namespace, template_name = template_address.replace("hub:template://", "").split(
116115
"/", 1
117116
)
118117
template_path = f"guard-templates/{namespace}/{template_name}"
119118
template_url = f"{VALIDATOR_HUB_SERVICE}/{template_path}"
120-
return fetch(template_url, token, creds.id)
119+
return fetch(template_url, token, settings.rc.id)
121120

122121

123122
# GET /guard-templates/{namespace}/{guardTemplateName}
@@ -164,10 +163,9 @@ def get_validator_manifest(module_name: str):
164163
# GET /auth
165164
def get_auth():
166165
try:
167-
creds = Credentials.from_rc_file(logger)
168-
token = get_jwt_token(creds)
166+
token = get_jwt_token(settings.rc)
169167
auth_url = f"{VALIDATOR_HUB_SERVICE}/auth"
170-
response = fetch(auth_url, token, creds.id)
168+
response = fetch(auth_url, token, settings.rc.id)
171169
if not response:
172170
raise AuthenticationError("Failed to authenticate!")
173171
except HttpError as http_error:
@@ -182,8 +180,7 @@ def get_auth():
182180

183181
def post_validator_submit(package_name: str, content: str):
184182
try:
185-
creds = Credentials.from_rc_file(logger)
186-
token = get_jwt_token(creds)
183+
token = get_jwt_token(settings.rc)
187184
submission_url = f"{VALIDATOR_HUB_SERVICE}/validator/submit"
188185

189186
headers = {

guardrails/cli/telemetry.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,12 @@
11
import platform
2-
from typing import Optional
3-
from guardrails.classes.credentials import Credentials
2+
from guardrails.settings import settings
43
from guardrails.utils.hub_telemetry_utils import HubTelemetry
54
from guardrails.version import GUARDRAILS_VERSION
6-
from guardrails.cli.logger import logger
7-
8-
config: Optional[Credentials] = None
9-
10-
11-
def load_config_file() -> Credentials:
12-
global config
13-
if not config:
14-
config = Credentials.from_rc_file(logger)
15-
return config
165

176

187
def trace_if_enabled(command_name: str):
19-
config = load_config_file()
20-
if config.enable_metrics is True:
21-
telemetry = HubTelemetry()
8+
if settings.rc.enable_metrics is True:
9+
telemetry = HubTelemetry(enabled=True)
2210
telemetry.create_new_span(
2311
f"guardrails-cli/{command_name}",
2412
[

guardrails/guard.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232

3333
from guardrails.api_client import GuardrailsApiClient
3434
from guardrails.classes.output_type import OT
35+
from guardrails.classes.rc import RC
3536
from guardrails.classes.validation.validation_result import ErrorSpan
3637
from guardrails.classes.validation_outcome import ValidationOutcome
37-
from guardrails.classes.credentials import Credentials
3838
from guardrails.classes.execution import GuardExecutionOptions
3939
from guardrails.classes.generic import Stack
4040
from guardrails.classes.history import Call
@@ -259,6 +259,7 @@ def configure(
259259
self._set_num_reasks(num_reasks)
260260
if tracer:
261261
self._set_tracer(tracer)
262+
self._load_rc()
262263
self._configure_hub_telemtry(allow_metrics_collection)
263264

264265
def _set_num_reasks(self, num_reasks: Optional[int] = None) -> None:
@@ -285,24 +286,26 @@ def _set_tracer(self, tracer: Optional[Tracer] = None) -> None:
285286
set_tracer_context()
286287
self._tracer_context = get_tracer_context()
287288

289+
def _load_rc(self) -> None:
290+
rc = RC.load(logger)
291+
settings.rc = rc
292+
288293
def _configure_hub_telemtry(
289294
self, allow_metrics_collection: Optional[bool] = None
290295
) -> None:
291-
credentials = None
292-
if allow_metrics_collection is None:
293-
credentials = Credentials.from_rc_file(logger)
294-
# TODO: Check credentials.enable_metrics after merge from main
295-
allow_metrics_collection = credentials.enable_metrics is True
296+
allow_metrics_collection = (
297+
settings.rc.enable_metrics is True
298+
if allow_metrics_collection is None
299+
else allow_metrics_collection
300+
)
296301

297302
self._allow_metrics_collection = allow_metrics_collection
298303

299-
if allow_metrics_collection:
300-
if not credentials:
301-
credentials = Credentials.from_rc_file(logger)
302-
# Get unique id of user from credentials
303-
self._user_id = credentials.id or ""
304+
if allow_metrics_collection is True:
305+
# Get unique id of user from rc file
306+
self._user_id = settings.rc.id or ""
304307
# Initialize Hub Telemetry singleton and get the tracer
305-
self._hub_telemetry = HubTelemetry()
308+
self._hub_telemetry = HubTelemetry(enabled=True)
306309

307310
def _fill_validator_map(self):
308311
# dont init validators if were going to call the server
@@ -920,6 +923,7 @@ def _exec(
920923
call = runner(call_log=call_log, prompt_params=prompt_params)
921924
return ValidationOutcome[OT].from_guard_history(call)
922925

926+
# @trace(name="Guard.__call__")
923927
def __call__(
924928
self,
925929
llm_api: Optional[Callable] = None,

0 commit comments

Comments
 (0)