Skip to content

Commit 026e36f

Browse files
committed
Add role parameter
1 parent 7c66e94 commit 026e36f

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,18 @@ conn = trino.dbapi.connect(
359359
)
360360
```
361361

362+
You could also pass `system` role without explicitly specifing "system" catalog:
363+
364+
```python
365+
import trino
366+
conn = trino.dbapi.connect(
367+
host='localhost',
368+
port=443,
369+
user='the-user',
370+
roles="role1" # equivalent to {"system": "role1"}
371+
)
372+
```
373+
362374
## Timezone
363375

364376
The time zone for the session can be explicitly set using the IANA time zone

tests/integration/test_dbapi_integration.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,18 @@ def test_set_role_in_connection(run_trino):
11441144
assert_role_headers(cur, "system=ALL")
11451145

11461146

1147+
def test_set_system_role_in_connection(run_trino):
1148+
_, host, port = run_trino
1149+
1150+
trino_connection = trino.dbapi.Connection(
1151+
host=host, port=port, user="test", catalog="tpch", roles="ALL"
1152+
)
1153+
cur = trino_connection.cursor()
1154+
cur.execute('SHOW TABLES FROM information_schema')
1155+
cur.fetchall()
1156+
assert_role_headers(cur, "system=ALL")
1157+
1158+
11471159
def assert_role_headers(cursor, expected_header):
11481160
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header
11491161

trino/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
transaction_id: str = None,
136136
extra_credential: List[Tuple[str, str]] = None,
137137
client_tags: List[str] = None,
138-
roles: Dict[str, str] = None,
138+
roles: Union[Dict[str, str], str] = None,
139139
timezone: str = None,
140140
):
141141
self._user = user
@@ -239,6 +239,8 @@ def timezone(self):
239239
return self._timezone
240240

241241
def _format_roles(self, roles):
242+
if isinstance(roles, str):
243+
roles = {"system": roles}
242244
formatted_roles = {}
243245
for catalog, role in roles.items():
244246
is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None

0 commit comments

Comments
 (0)