Skip to content

Commit 105761c

Browse files
committed
Improve dataset creation
1 parent 2e077e8 commit 105761c

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

cid/common.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from cid.plugin import Plugin
1717
from cid.utils import get_parameter, unset_parameter
1818
from cid.helpers.account_map import AccountMap
19-
from cid.helpers import Athena, CUR, Glue, QuickSight, Dashboard, Dataset
19+
from cid.helpers import Athena, CUR, Glue, QuickSight, Dashboard, Dataset, Datasource
2020
from cid._version import __version__
2121

2222

@@ -881,6 +881,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
881881
package_or_requirement=dataset_definition.get('providedBy'),
882882
resource_name=f'data/datasets/{dataset_file}',
883883
).decode('utf-8'))
884+
cur_required = dataset_definition.get('dependsOn', dict()).get('cur')
884885
athena_datasource = None
885886

886887

@@ -913,29 +914,17 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
913914
self.athena.DatabaseName = schemas[0]
914915
# else user will be suggested to choose database
915916
if len(datasources) == 1 and datasources[0] in self.qs.athena_datasources:
916-
athena_datasource = self.qs.get_datasources(id=datasources[0])
917+
athena_datasource = self.qs.get_datasources(id=datasources[0])[0]
917918
else:
918919
# FIXME: add user choice
919920
athena_datasource = next(iter(v for v in self.qs.athena_datasources.values()))
920921
logger.info(f'Found {len(datasources)} Athena datasources, using the first one {athena_datasource.id}')
921-
self.athena.WorkGroup = athena_datasource.AthenaParameters.get('WorkGroup')
922-
923-
columns_tpl = {
924-
'athena_datasource_arn': athena_datasource.arn if athena_datasource else None,
925-
'athena_database_name': self.athena.DatabaseName,
926-
'user_arn': self.qs.user.get('Arn')
927-
}
928-
if dataset_definition.get('dependsOn').get('cur'):
929-
columns_tpl['cur_table_name'] = self.cur.tableName
930-
931-
compiled_dataset = json.loads(template.safe_substitute(columns_tpl))
932-
if dataset_id:
933-
compiled_dataset.update({'DataSetId': dataset_id})
934-
922+
if isinstance(athena_datasource, Datasource):
923+
self.athena.WorkGroup = athena_datasource.AthenaParameters.get('WorkGroup')
935924

936925
# Check for required views
937926
_views = dataset_definition.get('dependsOn').get('views')
938-
required_views = [(self.cur.tableName if name =='${cur_table_name}' else name) for name in _views]
927+
required_views = [(self.cur.tableName if cur_required and name =='${cur_table_name}' else name) for name in _views]
939928

940929
self.athena.discover_views(required_views)
941930
found_views = utils.intersection(required_views, self.athena._metadata.keys())
@@ -944,7 +933,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
944933
if recursive:
945934
print(f"Detected views: {', '.join(found_views)}")
946935
for view_name in found_views:
947-
if self._clients.get('cur') and view_name == self.cur.tableName:
936+
if cur_required and view_name == self.cur.tableName:
948937
logger.debug(f'Dependancy view {view_name} is a CUR. Skip.')
949938
continue
950939
self.create_or_update_view(view_name, recursive=recursive, update=update)
@@ -955,6 +944,19 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
955944
for view_name in missing_views:
956945
self.create_or_update_view(view_name, recursive=recursive, update=update)
957946

947+
if not isinstance(athena_datasource, Datasource): return False
948+
# Proceed only if all the parameters are set
949+
columns_tpl = {
950+
'athena_datasource_arn': athena_datasource.arn,
951+
'athena_database_name': self.athena.DatabaseName,
952+
'cur_table_name': self.cur.tableName if cur_required else None,
953+
'user_arn': self.qs.user.get('Arn')
954+
}
955+
956+
compiled_dataset = json.loads(template.safe_substitute(columns_tpl))
957+
if dataset_id:
958+
compiled_dataset.update({'DataSetId': dataset_id})
959+
958960
found_dataset = self.qs.describe_dataset(compiled_dataset.get('DataSetId'))
959961
if isinstance(found_dataset, Dataset):
960962
if update:

cid/helpers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from cid.helpers.athena import Athena
22
from cid.helpers.cur import CUR
33
from cid.helpers.glue import Glue
4-
from cid.helpers.quicksight import QuickSight, Dashboard, Dataset
4+
from cid.helpers.quicksight import QuickSight, Dashboard, Dataset, Datasource
55

66

7-
__all__ = ["Athena", "CUR", "Glue", "QuickSight", "Dashboard", "Dataset"]
7+
__all__ = ["Athena", "CUR", "Glue", "QuickSight", "Dashboard", "Dataset", "Datasource"]

0 commit comments

Comments
 (0)