Skip to content

Commit e8a6445

Browse files
committed
Add role parameter
1 parent 7c66e94 commit e8a6445

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,7 @@ def test_set_role(run_trino):
11321132
assert_role_headers(cur, "system=ALL")
11331133

11341134

1135-
def test_set_role_in_connection(run_trino):
1135+
def test_set_roles_in_connection(run_trino):
11361136
_, host, port = run_trino
11371137

11381138
trino_connection = trino.dbapi.Connection(
@@ -1144,6 +1144,28 @@ def test_set_role_in_connection(run_trino):
11441144
assert_role_headers(cur, "system=ALL")
11451145

11461146

1147+
def test_set_role_in_connection(run_trino):
1148+
_, host, port = run_trino
1149+
1150+
trino_connection = trino.dbapi.Connection(
1151+
host=host, port=port, user="test", catalog="tpch", role="ALL"
1152+
)
1153+
cur = trino_connection.cursor()
1154+
cur.execute('SHOW TABLES FROM information_schema')
1155+
cur.fetchall()
1156+
assert_role_headers(cur, "system=ALL")
1157+
1158+
1159+
def test_set_role_and_roles_in_connection(run_trino):
1160+
_, host, port = run_trino
1161+
1162+
with pytest.raises(ValueError) as e:
1163+
trino.dbapi.Connection(
1164+
host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}, role="ALL"
1165+
)
1166+
assert "specify 'role' or 'roles' parameter, but not both" == str(e.value)
1167+
1168+
11471169
def assert_role_headers(cursor, expected_header):
11481170
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header
11491171

trino/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(
137137
client_tags: List[str] = None,
138138
roles: Dict[str, str] = None,
139139
timezone: str = None,
140+
role: str = None,
140141
):
141142
self._user = user
142143
self._catalog = catalog
@@ -147,7 +148,7 @@ def __init__(
147148
self._transaction_id = transaction_id
148149
self._extra_credential = extra_credential
149150
self._client_tags = client_tags.copy() if client_tags is not None else list()
150-
self._roles = self._format_roles(roles) if roles is not None else {}
151+
self._roles = self._format_roles(role, roles)
151152
self._prepared_statements: Dict[str, str] = {}
152153
self._object_lock = threading.Lock()
153154
self._timezone = timezone or get_localzone_name()
@@ -238,7 +239,13 @@ def timezone(self):
238239
with self._object_lock:
239240
return self._timezone
240241

241-
def _format_roles(self, roles):
242+
def _format_roles(self, role, roles):
243+
if role and roles:
244+
raise ValueError("specify 'role' or 'roles' parameter, but not both")
245+
elif role:
246+
roles = {"system": role}
247+
elif role is None and roles is None:
248+
return {}
242249
formatted_roles = {}
243250
for catalog, role in roles.items():
244251
is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None

trino/dbapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
legacy_primitive_types=False,
114114
roles=None,
115115
timezone=None,
116+
role=None,
116117
):
117118
self.host = host
118119
self.port = port
@@ -133,6 +134,7 @@ def __init__(
133134
client_tags=client_tags,
134135
roles=roles,
135136
timezone=timezone,
137+
role=role,
136138
)
137139
# mypy cannot follow module import
138140
if http_session is None:

0 commit comments

Comments
 (0)