Skip to content

Commit f44eb93

Browse files
authored
chore: simplify authorized sqlstore (#3496)
# What does this PR do? This PR is generated with AI and reviewed by me. Refactors the AuthorizedSqlStore class to store the access policy as an instance variable rather than passing it as a parameter to each method call. This simplifies the API. # Test Plan existing tests
1 parent d3600b9 commit f44eb93

File tree

7 files changed

+32
-37
lines changed

7 files changed

+32
-37
lines changed

llama_stack/providers/inline/files/localfs/files.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def initialize(self) -> None:
4444
storage_path.mkdir(parents=True, exist_ok=True)
4545

4646
# Initialize SQL store for metadata
47-
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
47+
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
4848
await self.sql_store.create_table(
4949
"openai_files",
5050
{
@@ -74,7 +74,7 @@ async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
7474
if not self.sql_store:
7575
raise RuntimeError("Files provider not initialized")
7676

77-
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
77+
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
7878
if not row:
7979
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
8080

@@ -150,7 +150,6 @@ async def openai_list_files(
150150

151151
paginated_result = await self.sql_store.fetch_all(
152152
table="openai_files",
153-
policy=self.policy,
154153
where=where_conditions if where_conditions else None,
155154
order_by=[("created_at", order.value)],
156155
cursor=("id", after) if after else None,

llama_stack/providers/remote/files/s3/files.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[st
137137
where: dict[str, str | dict] = {"id": file_id}
138138
if not return_expired:
139139
where["expires_at"] = {">": self._now()}
140-
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
140+
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
141141
raise ResourceNotFoundError(file_id, "File", "files.list()")
142142
return row
143143

@@ -164,7 +164,7 @@ async def initialize(self) -> None:
164164
self._client = _create_s3_client(self._config)
165165
await _create_bucket_if_not_exists(self._client, self._config)
166166

167-
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
167+
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
168168
await self._sql_store.create_table(
169169
"openai_files",
170170
{
@@ -268,7 +268,6 @@ async def openai_list_files(
268268

269269
paginated_result = await self.sql_store.fetch_all(
270270
table="openai_files",
271-
policy=self.policy,
272271
where=where_conditions,
273272
order_by=[("created_at", order.value)],
274273
cursor=("id", after) if after else None,

llama_stack/providers/utils/inference/inference_store.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454

5555
async def initialize(self):
5656
"""Create the necessary tables if they don't exist."""
57-
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
57+
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
5858
await self.sql_store.create_table(
5959
"chat_completions",
6060
{
@@ -202,7 +202,6 @@ async def list_chat_completions(
202202
order_by=[("created", order.value)],
203203
cursor=("id", after) if after else None,
204204
limit=limit,
205-
policy=self.policy,
206205
)
207206

208207
data = [
@@ -229,7 +228,6 @@ async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithI
229228
row = await self.sql_store.fetch_one(
230229
table="chat_completions",
231230
where={"id": completion_id},
232-
policy=self.policy,
233231
)
234232

235233
if not row:

llama_stack/providers/utils/responses/responses_store.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
2828
sql_store_config = SqliteSqlStoreConfig(
2929
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
3030
)
31-
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
32-
self.policy = policy
31+
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
3332

3433
async def initialize(self):
3534
"""Create the necessary tables if they don't exist."""
@@ -87,7 +86,6 @@ async def list_responses(
8786
order_by=[("created_at", order.value)],
8887
cursor=("id", after) if after else None,
8988
limit=limit,
90-
policy=self.policy,
9189
)
9290

9391
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@@ -105,7 +103,6 @@ async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWit
105103
row = await self.sql_store.fetch_one(
106104
"openai_responses",
107105
where={"id": response_id},
108-
policy=self.policy,
109106
)
110107

111108
if not row:
@@ -116,7 +113,7 @@ async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWit
116113
return OpenAIResponseObjectWithInput(**row["response_object"])
117114

118115
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
119-
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
116+
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
120117
if not row:
121118
raise ValueError(f"Response with id {response_id} not found")
122119
await self.sql_store.delete("openai_responses", where={"id": response_id})

llama_stack/providers/utils/sqlstore/authorized_sqlstore.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ class AuthorizedSqlStore:
5353
access control policies, user attribute capture, and SQL filtering optimization.
5454
"""
5555

56-
def __init__(self, sql_store: SqlStore):
56+
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
5757
"""
5858
Initialize the authorization layer.
5959
6060
:param sql_store: Base SqlStore implementation to wrap
61+
:param policy: Access control policy to use for authorization
6162
"""
6263
self.sql_store = sql_store
64+
self.policy = policy
6365
self._detect_database_type()
6466
self._validate_sql_optimized_policy()
6567

@@ -117,14 +119,13 @@ async def insert(self, table: str, data: Mapping[str, Any]) -> None:
117119
async def fetch_all(
118120
self,
119121
table: str,
120-
policy: list[AccessRule],
121122
where: Mapping[str, Any] | None = None,
122123
limit: int | None = None,
123124
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
124125
cursor: tuple[str, str] | None = None,
125126
) -> PaginatedResponse:
126127
"""Fetch all rows with automatic access control filtering."""
127-
access_where = self._build_access_control_where_clause(policy)
128+
access_where = self._build_access_control_where_clause(self.policy)
128129
rows = await self.sql_store.fetch_all(
129130
table=table,
130131
where=where,
@@ -146,7 +147,7 @@ async def fetch_all(
146147
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
147148
)
148149

149-
if is_action_allowed(policy, Action.READ, sql_record, current_user):
150+
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
150151
filtered_rows.append(row)
151152

152153
return PaginatedResponse(
@@ -157,14 +158,12 @@ async def fetch_all(
157158
async def fetch_one(
158159
self,
159160
table: str,
160-
policy: list[AccessRule],
161161
where: Mapping[str, Any] | None = None,
162162
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
163163
) -> dict[str, Any] | None:
164164
"""Fetch one row with automatic access control checking."""
165165
results = await self.fetch_all(
166166
table=table,
167-
policy=policy,
168167
where=where,
169168
limit=1,
170169
order_by=order_by,

tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def authorized_store(backend_config):
5757
config = config_func()
5858

5959
base_sqlstore = sqlstore_impl(config)
60-
authorized_store = AuthorizedSqlStore(base_sqlstore)
60+
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
6161

6262
yield authorized_store
6363

@@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
106106
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
107107

108108
# Test fetching with no user - should not error on JSON comparison
109-
result = await authorized_store.fetch_all(table_name, policy=default_policy())
109+
result = await authorized_store.fetch_all(table_name)
110110
assert len(result.data) == 1
111111
assert result.data[0]["id"] == "1"
112112
assert result.data[0]["access_attributes"] is None
@@ -119,15 +119,15 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
119119
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
120120

121121
# Fetch all - admin should see both
122-
result = await authorized_store.fetch_all(table_name, policy=default_policy())
122+
result = await authorized_store.fetch_all(table_name)
123123
assert len(result.data) == 2
124124

125125
# Test with non-admin user
126126
regular_user = User("regular-user", {"roles": ["user"]})
127127
mock_get_authenticated_user.return_value = regular_user
128128

129129
# Should only see public record
130-
result = await authorized_store.fetch_all(table_name, policy=default_policy())
130+
result = await authorized_store.fetch_all(table_name)
131131
assert len(result.data) == 1
132132
assert result.data[0]["id"] == "1"
133133

@@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
156156

157157
# Now test with the multi-user who has both roles=admin and teams=dev
158158
mock_get_authenticated_user.return_value = multi_user
159-
result = await authorized_store.fetch_all(table_name, policy=default_policy())
159+
result = await authorized_store.fetch_all(table_name)
160160

161161
# Should see:
162162
# - public record (1) - no access_attributes
@@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
217217
),
218218
]
219219

220+
# Create a new authorized store with the owner-only policy
221+
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
222+
220223
# Test user1 access - should only see their own record
221224
mock_get_authenticated_user.return_value = user1
222-
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
225+
result = await owner_only_store.fetch_all(table_name)
223226
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
224227
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
225228

226229
# Test user2 access - should only see their own record
227230
mock_get_authenticated_user.return_value = user2
228-
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
231+
result = await owner_only_store.fetch_all(table_name)
229232
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
230233
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
231234

232235
# Test with anonymous user - should see no records
233236
mock_get_authenticated_user.return_value = None
234-
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
237+
result = await owner_only_store.fetch_all(table_name)
235238
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
236239

237240
finally:

tests/unit/utils/test_authorized_sqlstore.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
2626
db_path=tmp_dir + "/" + db_name,
2727
)
2828
)
29-
sqlstore = AuthorizedSqlStore(base_sqlstore)
29+
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
3030

3131
# Create table with access control
3232
await sqlstore.create_table(
@@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
5656
mock_get_authenticated_user.return_value = admin_user
5757

5858
# Admin should see both documents
59-
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
59+
result = await sqlstore.fetch_all("documents", where={"id": 1})
6060
assert len(result.data) == 1
6161
assert result.data[0]["title"] == "Admin Document"
6262

6363
# User should only see their document
6464
mock_get_authenticated_user.return_value = regular_user
6565

66-
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
66+
result = await sqlstore.fetch_all("documents", where={"id": 1})
6767
assert len(result.data) == 0
6868

69-
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
69+
result = await sqlstore.fetch_all("documents", where={"id": 2})
7070
assert len(result.data) == 1
7171
assert result.data[0]["title"] == "User Document"
7272

73-
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
73+
row = await sqlstore.fetch_one("documents", where={"id": 1})
7474
assert row is None
7575

76-
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
76+
row = await sqlstore.fetch_one("documents", where={"id": 2})
7777
assert row is not None
7878
assert row["title"] == "User Document"
7979

@@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
8888
db_path=tmp_dir + "/" + db_name,
8989
)
9090
)
91-
sqlstore = AuthorizedSqlStore(base_sqlstore)
91+
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
9292

9393
await sqlstore.create_table(
9494
table="resources",
@@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
144144
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
145145
mock_get_authenticated_user.return_value = user
146146

147-
sql_results = await sqlstore.fetch_all("resources", policy=policy)
147+
sql_results = await sqlstore.fetch_all("resources")
148148
sql_ids = {row["id"] for row in sql_results.data}
149149
policy_ids = set()
150150
for scenario in test_scenarios:
@@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
174174
db_path=tmp_dir + "/" + db_name,
175175
)
176176
)
177-
authorized_store = AuthorizedSqlStore(base_sqlstore)
177+
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
178178

179179
await authorized_store.create_table(
180180
table="user_data",

0 commit comments

Comments
 (0)