Skip to content

Commit fbe4aa4

Browse files
authored
feat: fetch and store more data about okta authorization server (#3894)
1 parent c205d2e commit fbe4aa4

File tree

13 files changed

+323
-83
lines changed

13 files changed

+323
-83
lines changed

lib/crewai/src/crewai/cli/authentication/main.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Any
2+
from typing import TYPE_CHECKING, Any, TypeVar, cast
33
import webbrowser
44

55
from pydantic import BaseModel, Field
@@ -13,6 +13,8 @@
1313

1414
console = Console()
1515

16+
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
17+
1618

1719
class Oauth2Settings(BaseModel):
1820
provider: str = Field(
@@ -28,22 +30,36 @@ class Oauth2Settings(BaseModel):
2830
description="OAuth2 audience value, typically used to identify the target API or resource.",
2931
default=None,
3032
)
33+
extra: dict[str, Any] = Field(
34+
description="Extra configuration for the OAuth2 provider.",
35+
default={},
36+
)
3137

3238
@classmethod
33-
def from_settings(cls):
39+
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
40+
"""Create an Oauth2Settings instance from the CLI settings."""
41+
3442
settings = Settings()
3543

3644
return cls(
3745
provider=settings.oauth2_provider,
3846
domain=settings.oauth2_domain,
3947
client_id=settings.oauth2_client_id,
4048
audience=settings.oauth2_audience,
49+
extra=settings.oauth2_extra,
4150
)
4251

4352

53+
if TYPE_CHECKING:
54+
from crewai.cli.authentication.providers.base_provider import BaseProvider
55+
56+
4457
class ProviderFactory:
4558
@classmethod
46-
def from_settings(cls, settings: Oauth2Settings | None = None):
59+
def from_settings(
60+
cls: type["ProviderFactory"], # noqa: UP037
61+
settings: Oauth2Settings | None = None,
62+
) -> "BaseProvider": # noqa: UP037
4763
settings = settings or Oauth2Settings.from_settings()
4864

4965
import importlib
@@ -53,11 +69,11 @@ def from_settings(cls, settings: Oauth2Settings | None = None):
5369
)
5470
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
5571

56-
return provider(settings)
72+
return cast("BaseProvider", provider(settings))
5773

5874

5975
class AuthenticationCommand:
60-
def __init__(self):
76+
def __init__(self) -> None:
6177
self.token_manager = TokenManager()
6278
self.oauth2_provider = ProviderFactory.from_settings()
6379

@@ -84,7 +100,7 @@ def _get_device_code(self) -> dict[str, Any]:
84100
timeout=20,
85101
)
86102
response.raise_for_status()
87-
return response.json()
103+
return cast(dict[str, Any], response.json())
88104

89105
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
90106
"""Display the authentication instructions to the user."""

lib/crewai/src/crewai/cli/authentication/providers/base_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ def get_audience(self) -> str: ...
2424

2525
@abstractmethod
2626
def get_client_id(self) -> str: ...
27+
28+
def get_required_fields(self) -> list[str]:
29+
"""Returns which provider-specific fields inside the "extra" dict will be required"""
30+
return []

lib/crewai/src/crewai/cli/authentication/providers/okta.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
class OktaProvider(BaseProvider):
55
def get_authorize_url(self) -> str:
6-
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
6+
return f"{self._oauth2_base_url()}/v1/device/authorize"
77

88
def get_token_url(self) -> str:
9-
return f"https://{self.settings.domain}/oauth2/default/v1/token"
9+
return f"{self._oauth2_base_url()}/v1/token"
1010

1111
def get_jwks_url(self) -> str:
12-
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
12+
return f"{self._oauth2_base_url()}/v1/keys"
1313

1414
def get_issuer(self) -> str:
15-
return f"https://{self.settings.domain}/oauth2/default"
15+
return self._oauth2_base_url().removesuffix("/oauth2")
1616

1717
def get_audience(self) -> str:
1818
if self.settings.audience is None:
@@ -27,3 +27,16 @@ def get_client_id(self) -> str:
2727
"Client ID is required. Please set it in the configuration."
2828
)
2929
return self.settings.client_id
30+
31+
def get_required_fields(self) -> list[str]:
32+
return ["authorization_server_name", "using_org_auth_server"]
33+
34+
def _oauth2_base_url(self) -> str:
35+
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
36+
37+
if using_org_auth_server:
38+
base_url = f"https://{self.settings.domain}/oauth2"
39+
else:
40+
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
41+
42+
return f"{base_url}"

lib/crewai/src/crewai/cli/command.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@
1111

1212

1313
class BaseCommand:
14-
def __init__(self):
14+
def __init__(self) -> None:
1515
self._telemetry = Telemetry()
1616
self._telemetry.set_tracer()
1717

1818

1919
class PlusAPIMixin:
20-
def __init__(self, telemetry):
20+
def __init__(self, telemetry: Telemetry) -> None:
2121
try:
2222
telemetry.set_tracer()
2323
self.plus_api_client = PlusAPI(api_key=get_auth_token())
2424
except Exception:
25-
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
25+
telemetry.deploy_signup_error_span()
2626
console.print(
2727
"Please sign up/login to CrewAI+ before using the CLI.",
2828
style="bold red",

lib/crewai/src/crewai/cli/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from logging import getLogger
33
from pathlib import Path
44
import tempfile
5+
from typing import Any
56

67
from pydantic import BaseModel, Field
78

@@ -136,7 +137,12 @@ class Settings(BaseModel):
136137
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
137138
)
138139

139-
def __init__(self, config_path: Path | None = None, **data):
140+
oauth2_extra: dict[str, Any] = Field(
141+
description="Extra configuration for the OAuth2 provider.",
142+
default={},
143+
)
144+
145+
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
140146
"""Load Settings from config path with fallback support"""
141147
if config_path is None:
142148
config_path = get_writable_config_path()

lib/crewai/src/crewai/cli/enterprise/main.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any
1+
from typing import Any, cast
22

33
import requests
44
from requests.exceptions import JSONDecodeError, RequestException
55
from rich.console import Console
66

7+
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
78
from crewai.cli.command import BaseCommand
89
from crewai.cli.settings.main import SettingsCommand
910
from crewai.cli.version import get_crewai_version
@@ -13,7 +14,7 @@
1314

1415

1516
class EnterpriseConfigureCommand(BaseCommand):
16-
def __init__(self):
17+
def __init__(self) -> None:
1718
super().__init__()
1819
self.settings_command = SettingsCommand()
1920

@@ -54,25 +55,12 @@ def _fetch_oauth_config(self, enterprise_url: str) -> dict[str, Any]:
5455
except JSONDecodeError as e:
5556
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
5657

57-
required_fields = [
58-
"audience",
59-
"domain",
60-
"device_authorization_client_id",
61-
"provider",
62-
]
63-
missing_fields = [
64-
field for field in required_fields if field not in oauth_config
65-
]
66-
67-
if missing_fields:
68-
raise ValueError(
69-
f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}"
70-
)
58+
self._validate_oauth_config(oauth_config)
7159

7260
console.print(
7361
"✅ Successfully retrieved OAuth2 configuration", style="green"
7462
)
75-
return oauth_config
63+
return cast(dict[str, Any], oauth_config)
7664

7765
except RequestException as e:
7866
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
@@ -89,6 +77,7 @@ def _update_oauth_settings(
8977
"oauth2_audience": oauth_config["audience"],
9078
"oauth2_client_id": oauth_config["device_authorization_client_id"],
9179
"oauth2_domain": oauth_config["domain"],
80+
"oauth2_extra": oauth_config["extra"],
9281
}
9382

9483
console.print("🔄 Updating local OAuth2 configuration...")
@@ -99,3 +88,38 @@ def _update_oauth_settings(
9988

10089
except Exception as e:
10190
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
91+
92+
def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
93+
required_fields = [
94+
"audience",
95+
"domain",
96+
"device_authorization_client_id",
97+
"provider",
98+
"extra",
99+
]
100+
101+
missing_basic_fields = [
102+
field for field in required_fields if field not in oauth_config
103+
]
104+
missing_provider_specific_fields = [
105+
field
106+
for field in self._get_provider_specific_fields(oauth_config["provider"])
107+
if field not in oauth_config.get("extra", {})
108+
]
109+
110+
if missing_basic_fields:
111+
raise ValueError(
112+
f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
113+
)
114+
115+
if missing_provider_specific_fields:
116+
raise ValueError(
117+
f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
118+
)
119+
120+
def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
121+
provider = ProviderFactory.from_settings(
122+
Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
123+
)
124+
125+
return provider.get_required_fields()

lib/crewai/src/crewai/cli/git.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
class Repository:
6-
def __init__(self, path="."):
6+
def __init__(self, path: str = ".") -> None:
77
self.path = path
88

99
if not self.is_git_installed():

lib/crewai/src/crewai/cli/plus_api.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from urllib.parse import urljoin
23

34
import requests
@@ -36,19 +37,21 @@ def __init__(self, api_key: str) -> None:
3637
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
3738
)
3839

39-
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
40+
def _make_request(
41+
self, method: str, endpoint: str, **kwargs: Any
42+
) -> requests.Response:
4043
url = urljoin(self.base_url, endpoint)
4144
session = requests.Session()
4245
session.trust_env = False
4346
return session.request(method, url, headers=self.headers, **kwargs)
4447

45-
def login_to_tool_repository(self):
48+
def login_to_tool_repository(self) -> requests.Response:
4649
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
4750

48-
def get_tool(self, handle: str):
51+
def get_tool(self, handle: str) -> requests.Response:
4952
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
5053

51-
def get_agent(self, handle: str):
54+
def get_agent(self, handle: str) -> requests.Response:
5255
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
5356

5457
def publish_tool(
@@ -58,8 +61,8 @@ def publish_tool(
5861
version: str,
5962
description: str | None,
6063
encoded_file: str,
61-
available_exports: list[str] | None = None,
62-
):
64+
available_exports: list[dict[str, Any]] | None = None,
65+
) -> requests.Response:
6366
params = {
6467
"handle": handle,
6568
"public": is_public,
@@ -111,28 +114,32 @@ def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
111114
def list_crews(self) -> requests.Response:
112115
return self._make_request("GET", self.CREWS_RESOURCE)
113116

114-
def create_crew(self, payload) -> requests.Response:
117+
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
115118
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
116119

117120
def get_organizations(self) -> requests.Response:
118121
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
119122

120-
def initialize_trace_batch(self, payload) -> requests.Response:
123+
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
121124
return self._make_request(
122125
"POST",
123126
f"{self.TRACING_RESOURCE}/batches",
124127
json=payload,
125128
timeout=30,
126129
)
127130

128-
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
131+
def initialize_ephemeral_trace_batch(
132+
self, payload: dict[str, Any]
133+
) -> requests.Response:
129134
return self._make_request(
130135
"POST",
131136
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
132137
json=payload,
133138
)
134139

135-
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
140+
def send_trace_events(
141+
self, trace_batch_id: str, payload: dict[str, Any]
142+
) -> requests.Response:
136143
return self._make_request(
137144
"POST",
138145
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -141,7 +148,7 @@ def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
141148
)
142149

143150
def send_ephemeral_trace_events(
144-
self, trace_batch_id: str, payload
151+
self, trace_batch_id: str, payload: dict[str, Any]
145152
) -> requests.Response:
146153
return self._make_request(
147154
"POST",
@@ -150,7 +157,9 @@ def send_ephemeral_trace_events(
150157
timeout=30,
151158
)
152159

153-
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
160+
def finalize_trace_batch(
161+
self, trace_batch_id: str, payload: dict[str, Any]
162+
) -> requests.Response:
154163
return self._make_request(
155164
"PATCH",
156165
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -159,7 +168,7 @@ def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Respons
159168
)
160169

161170
def finalize_ephemeral_trace_batch(
162-
self, trace_batch_id: str, payload
171+
self, trace_batch_id: str, payload: dict[str, Any]
163172
) -> requests.Response:
164173
return self._make_request(
165174
"PATCH",

lib/crewai/src/crewai/cli/settings/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def list(self) -> None:
3434
current_value = getattr(self.settings, field_name)
3535
description = field_info.description or "No description available"
3636
display_value = (
37-
str(current_value) if current_value is not None else "Not set"
37+
str(current_value) if current_value not in [None, {}] else "Not set"
3838
)
3939

4040
table.add_row(field_name, display_value, description)

0 commit comments

Comments
 (0)