88import tempfile
99from typing import Union
1010
11- import cosmotech_api
1211from azure .digitaltwins .core import DigitalTwinsClient
1312from azure .identity import DefaultAzureCredential
1413from cosmotech_api import DatasetApi
14+ from cosmotech_api import DatasetTwinGraphQuery
1515from cosmotech_api import ScenarioApi
16+ from cosmotech_api import TwinGraphQuery
1617from cosmotech_api import TwingraphApi
1718from cosmotech_api import WorkspaceApi
18- from cosmotech_api import DatasetTwinGraphQuery
19- from cosmotech_api import TwinGraphQuery
2019from openpyxl import load_workbook
2120
2221from cosmotech .coal .cosmotech_api .connection import get_api_client
@@ -85,7 +84,6 @@ def __init__(
8584 self .credentials = DefaultAzureCredential ()
8685 else :
8786 self .credentials = None
88-
8987
9088 self .workspace_id = workspace_id
9189 self .organization_id = organization_id
@@ -104,19 +102,17 @@ def get_scenario_data(self, scenario_id: str):
104102 def download_dataset (self , dataset_id : str ) -> (str , str , Union [str , None ]):
105103 with get_api_client ()[0 ] as api_client :
106104 api_instance = DatasetApi (api_client )
107-
108105 dataset = api_instance .find_dataset_by_id (
109106 organization_id = self .organization_id ,
110107 dataset_id = dataset_id )
111108 if dataset .connector is None :
112109 parameters = []
113110 else :
114111 parameters = dataset .connector .parameters_values
115-
116112 is_adt = 'AZURE_DIGITAL_TWINS_URL' in parameters
117113 is_storage = 'AZURE_STORAGE_CONTAINER_BLOB_PREFIX' in parameters
118114 is_legacy_twin_cache = 'TWIN_CACHE_NAME' in parameters and dataset .twingraph_id is None # Legacy twingraph dataset with specific connector
119- is_in_workspace_file = 'workspaceFile' in dataset .tags
115+ is_in_workspace_file = False if dataset . tags is None else 'workspaceFile' in dataset .tags
120116
121117 if is_adt :
122118 return {
@@ -329,7 +325,7 @@ def get_all_datasets(self, scenario_id: str) -> dict:
329325 dataset_ids = datasets [:]
330326
331327 for parameter in scenario_data .parameters_values :
332- if parameter .var_type == '%DATASETID%' :
328+ if parameter .var_type == '%DATASETID%' and parameter . value :
333329 dataset_id = parameter .value
334330 dataset_ids .append (dataset_id )
335331
@@ -344,7 +340,7 @@ def download_dataset_process(_dataset_id, _return_dict, _error_dict):
344340 _error_dict [_dataset_id ] = f'{ type (e ).__name__ } : { str (e )} '
345341 raise e
346342
347- if self .parallel :
343+ if self .parallel and len ( dataset_ids ) > 1 :
348344 manager = multiprocessing .Manager ()
349345 return_dict = manager .dict ()
350346 error_dict = manager .dict ()
0 commit comments