Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions backend/open_webui/test/test_oauth_google_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import aiohttp
from open_webui.utils.oauth import OAuthManager
from open_webui.config import AppConfig


class TestOAuthGoogleGroups:
"""Basic tests for Google OAuth Groups functionality"""

def setup_method(self):
"""Setup test fixtures"""
self.oauth_manager = OAuthManager(app=MagicMock())

@pytest.mark.asyncio
async def test_fetch_google_groups_success(self):
"""Test successful Google groups fetching with proper aiohttp mocking"""
# Mock response data from Google Cloud Identity API
mock_response_data = {
"memberships": [
{
"groupKey": {"id": "[email protected]"},
"group": "groups/123",
"displayName": "Admin Group"
},
{
"groupKey": {"id": "[email protected]"},
"group": "groups/456",
"displayName": "Users Group"
}
]
}

# Create properly structured async mocks
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_response_data)

# Mock the async context manager for session.get()
mock_get_context = MagicMock()
mock_get_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_get_context.__aexit__ = AsyncMock(return_value=None)

# Mock the session
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=mock_get_context)

# Mock the async context manager for ClientSession
mock_session_context = MagicMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock(return_value=None)

with patch("aiohttp.ClientSession", return_value=mock_session_context):
groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity(
access_token="test_token",
user_email="[email protected]"
)

# Verify the results
assert groups == ["[email protected]", "[email protected]"]

# Verify the HTTP call was made correctly
mock_session.get.assert_called_once()
call_args = mock_session.get.call_args

# Check the URL contains the user email (URL encoded)
url_arg = call_args[0][0] # First positional argument
assert "user%40company.com" in url_arg # @ is encoded as %40
assert "searchTransitiveGroups" in url_arg

# Check headers contain the bearer token
headers_arg = call_args[1]["headers"] # headers keyword argument
assert headers_arg["Authorization"] == "Bearer test_token"
assert headers_arg["Content-Type"] == "application/json"

@pytest.mark.asyncio
async def test_fetch_google_groups_api_error(self):
"""Test handling of API errors when fetching groups"""
# Mock failed response
mock_response = MagicMock()
mock_response.status = 403
mock_response.text = AsyncMock(return_value="Permission denied")

# Mock the async context manager for session.get()
mock_get_context = MagicMock()
mock_get_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_get_context.__aexit__ = AsyncMock(return_value=None)

# Mock the session
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=mock_get_context)

# Mock the async context manager for ClientSession
mock_session_context = MagicMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock(return_value=None)

with patch("aiohttp.ClientSession", return_value=mock_session_context):
groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity(
access_token="test_token",
user_email="[email protected]"
)

# Should return empty list on error
assert groups == []

@pytest.mark.asyncio
async def test_fetch_google_groups_network_error(self):
"""Test handling of network errors when fetching groups"""
# Mock the session that raises an exception when get() is called
mock_session = MagicMock()
mock_session.get.side_effect = aiohttp.ClientError("Network error")

# Mock the async context manager for ClientSession
mock_session_context = MagicMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock(return_value=None)

with patch("aiohttp.ClientSession", return_value=mock_session_context):
groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity(
access_token="test_token",
user_email="[email protected]"
)

# Should return empty list on network error
assert groups == []

@pytest.mark.asyncio
async def test_get_user_role_with_google_groups(self):
"""Test role assignment using Google groups"""
# Mock configuration
mock_config = MagicMock()
mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
mock_config.OAUTH_ROLES_CLAIM = "groups"
mock_config.OAUTH_ALLOWED_ROLES = ["[email protected]"]
mock_config.OAUTH_ADMIN_ROLES = ["[email protected]"]
mock_config.DEFAULT_USER_ROLE = "pending"
mock_config.OAUTH_EMAIL_CLAIM = "email"

user_data = {"email": "[email protected]"}

# Mock Google OAuth scope check and Users class
with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \
patch("open_webui.utils.oauth.Users") as mock_users, \
patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch:

mock_scope.value = "openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly"
mock_fetch.return_value = ["[email protected]", "[email protected]"]
mock_users.get_num_users.return_value = 5 # Not first user

role = await self.oauth_manager.get_user_role(
user=None,
user_data=user_data,
provider="google",
access_token="test_token"
)

# Should assign admin role since user is in admin group
assert role == "admin"
mock_fetch.assert_called_once_with("test_token", "[email protected]")

@pytest.mark.asyncio
async def test_get_user_role_fallback_to_claims(self):
"""Test fallback to traditional claims when Google groups fail"""
mock_config = MagicMock()
mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
mock_config.OAUTH_ROLES_CLAIM = "groups"
mock_config.OAUTH_ALLOWED_ROLES = ["users"]
mock_config.OAUTH_ADMIN_ROLES = ["admin"]
mock_config.DEFAULT_USER_ROLE = "pending"
mock_config.OAUTH_EMAIL_CLAIM = "email"

user_data = {
"email": "[email protected]",
"groups": ["users"]
}

with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \
patch("open_webui.utils.oauth.Users") as mock_users, \
patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch:

# Mock scope without Cloud Identity
mock_scope.value = "openid email profile"
mock_users.get_num_users.return_value = 5 # Not first user

role = await self.oauth_manager.get_user_role(
user=None,
user_data=user_data,
provider="google",
access_token="test_token"
)

# Should use traditional claims since Cloud Identity scope not present
assert role == "user"
mock_fetch.assert_not_called()

@pytest.mark.asyncio
async def test_get_user_role_non_google_provider(self):
"""Test that non-Google providers use traditional claims"""
mock_config = MagicMock()
mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
mock_config.OAUTH_ROLES_CLAIM = "roles"
mock_config.OAUTH_ALLOWED_ROLES = ["user"]
mock_config.OAUTH_ADMIN_ROLES = ["admin"]
mock_config.DEFAULT_USER_ROLE = "pending"

user_data = {"roles": ["user"]}

with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
patch("open_webui.utils.oauth.Users") as mock_users, \
patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch:

mock_users.get_num_users.return_value = 5 # Not first user

role = await self.oauth_manager.get_user_role(
user=None,
user_data=user_data,
provider="microsoft",
access_token="test_token"
)

# Should use traditional claims for non-Google providers
assert role == "user"
mock_fetch.assert_not_called()

@pytest.mark.asyncio
async def test_update_user_groups_with_google_groups(self):
"""Test group management using Google groups from user_data"""
mock_config = MagicMock()
mock_config.OAUTH_GROUPS_CLAIM = "groups"
mock_config.OAUTH_BLOCKED_GROUPS = "[]"
mock_config.ENABLE_OAUTH_GROUP_CREATION = False

# Mock user with Google groups data
mock_user = MagicMock()
mock_user.id = "user123"

user_data = {
"google_groups": ["[email protected]", "[email protected]"]
}

# Mock existing groups and user groups
mock_existing_group = MagicMock()
mock_existing_group.name = "[email protected]"
mock_existing_group.id = "group1"
mock_existing_group.user_ids = []
mock_existing_group.permissions = {"read": True}
mock_existing_group.description = "Developers group"

with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
patch("open_webui.utils.oauth.Groups") as mock_groups:

mock_groups.get_groups_by_member_id.return_value = []
mock_groups.get_groups.return_value = [mock_existing_group]

await self.oauth_manager.update_user_groups(
user=mock_user,
user_data=user_data,
default_permissions={"read": True}
)

# Should use Google groups instead of traditional claims
mock_groups.get_groups_by_member_id.assert_called_once_with("user123")
mock_groups.update_group_by_id.assert_called()
Loading
Loading