1616from cid .plugin import Plugin
1717from cid .utils import get_parameter , unset_parameter
1818from cid .helpers .account_map import AccountMap
19- from cid .helpers import Athena , CUR , Glue , QuickSight , Dashboard , Dataset
19+ from cid .helpers import Athena , CUR , Glue , QuickSight , Dashboard , Dataset , Datasource
2020from cid ._version import __version__
2121
2222
@@ -881,6 +881,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
881881 package_or_requirement = dataset_definition .get ('providedBy' ),
882882 resource_name = f'data/datasets/{ dataset_file } ' ,
883883 ).decode ('utf-8' ))
884+ cur_required = dataset_definition .get ('dependsOn' , dict ()).get ('cur' )
884885 athena_datasource = None
885886
886887
@@ -913,29 +914,17 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
913914 self .athena .DatabaseName = schemas [0 ]
914915 # else user will be suggested to choose database
915916 if len (datasources ) == 1 and datasources [0 ] in self .qs .athena_datasources :
916- athena_datasource = self .qs .get_datasources (id = datasources [0 ])
917+ athena_datasource = self .qs .get_datasources (id = datasources [0 ])[ 0 ]
917918 else :
918919 # FIXME: add user choice
919920 athena_datasource = next (iter (v for v in self .qs .athena_datasources .values ()))
920921 logger .info (f'Found { len (datasources )} Athena datasources, using the first one { athena_datasource .id } ' )
921- self .athena .WorkGroup = athena_datasource .AthenaParameters .get ('WorkGroup' )
922-
923- columns_tpl = {
924- 'athena_datasource_arn' : athena_datasource .arn if athena_datasource else None ,
925- 'athena_database_name' : self .athena .DatabaseName ,
926- 'user_arn' : self .qs .user .get ('Arn' )
927- }
928- if dataset_definition .get ('dependsOn' ).get ('cur' ):
929- columns_tpl ['cur_table_name' ] = self .cur .tableName
930-
931- compiled_dataset = json .loads (template .safe_substitute (columns_tpl ))
932- if dataset_id :
933- compiled_dataset .update ({'DataSetId' : dataset_id })
934-
922+ if isinstance (athena_datasource , Datasource ):
923+ self .athena .WorkGroup = athena_datasource .AthenaParameters .get ('WorkGroup' )
935924
936925 # Check for required views
937926 _views = dataset_definition .get ('dependsOn' ).get ('views' )
938- required_views = [(self .cur .tableName if name == '${cur_table_name}' else name ) for name in _views ]
927+ required_views = [(self .cur .tableName if cur_required and name == '${cur_table_name}' else name ) for name in _views ]
939928
940929 self .athena .discover_views (required_views )
941930 found_views = utils .intersection (required_views , self .athena ._metadata .keys ())
@@ -944,7 +933,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
944933 if recursive :
945934 print (f"Detected views: { ', ' .join (found_views )} " )
946935 for view_name in found_views :
947- if self . _clients . get ( 'cur' ) and view_name == self .cur .tableName :
936+ if cur_required and view_name == self .cur .tableName :
948937 logger .debug (f'Dependancy view { view_name } is a CUR. Skip.' )
949938 continue
950939 self .create_or_update_view (view_name , recursive = recursive , update = update )
@@ -955,6 +944,19 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
955944 for view_name in missing_views :
956945 self .create_or_update_view (view_name , recursive = recursive , update = update )
957946
947+ if not isinstance (athena_datasource , Datasource ): return False
948+ # Proceed only if all the parameters are set
949+ columns_tpl = {
950+ 'athena_datasource_arn' : athena_datasource .arn ,
951+ 'athena_database_name' : self .athena .DatabaseName ,
952+ 'cur_table_name' : self .cur .tableName if cur_required else None ,
953+ 'user_arn' : self .qs .user .get ('Arn' )
954+ }
955+
956+ compiled_dataset = json .loads (template .safe_substitute (columns_tpl ))
957+ if dataset_id :
958+ compiled_dataset .update ({'DataSetId' : dataset_id })
959+
958960 found_dataset = self .qs .describe_dataset (compiled_dataset .get ('DataSetId' ))
959961 if isinstance (found_dataset , Dataset ):
960962 if update :
0 commit comments