Skip to content

Commit b740616

Browse files
committed
Add workgroup global config. #437
1 parent 69850e5 commit b740616

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

awswrangler/_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class _ConfigArg(NamedTuple):
3030
"max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False),
3131
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
3232
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
33+
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
3334
}
3435

3536

@@ -214,6 +215,15 @@ def s3_block_size(self) -> int:
214215
def s3_block_size(self, value: int) -> None:
215216
self._set_config_value(key="s3_block_size", value=value)
216217

218+
@property
219+
def workgroup(self) -> Optional[str]:
220+
"""Property workgroup."""
221+
return cast(Optional[str], self["workgroup"])
222+
223+
@workgroup.setter
224+
def workgroup(self, value: Optional[str]) -> None:
225+
self._set_config_value(key="workgroup", value=value)
226+
217227

218228
def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -> str:
219229
if doc is None:

tests/test_config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
1010

1111

12-
def test_basics(path, glue_database, glue_table):
12+
def test_basics(path, glue_database, glue_table, workgroup0, workgroup1):
1313
args = {"table": glue_table, "path": "", "columns_types": {"col0": "bigint"}}
1414

1515
# Missing database argument
@@ -62,3 +62,12 @@ def test_basics(path, glue_database, glue_table):
6262
wr.catalog.does_table_exist(table=glue_table)
6363

6464
assert wr.config.to_pandas().shape == (len(wr._config._CONFIG_ARGS), 7)
65+
66+
# Workgroup
67+
wr.config.workgroup = workgroup0
68+
df = wr.athena.read_sql_query(sql="SELECT 1 as col0", database=glue_database)
69+
assert df.query_metadata["WorkGroup"] == workgroup0
70+
os.environ["WR_WORKGROUP"] = workgroup1
71+
wr.config.reset()
72+
df = wr.athena.read_sql_query(sql="SELECT 1 as col0", database=glue_database)
73+
assert df.query_metadata["WorkGroup"] == workgroup1

0 commit comments

Comments
 (0)