|
14 | 14 | from prettytable import PrettyTable |
15 | 15 | import matplotlib.pyplot as plt |
16 | 16 |
|
17 | | -weight_download_url="https://drive.google.com/file/d/1zZU3b6bLqTHLuxuFWt80wrfJsVROJuNP/view?usp=sharing" |
18 | | - |
19 | 17 | def collect_and_prepare(group_by_key: Group, search_parameters: dict[str, Any], images_per_group: int, output_directory: str): |
20 | 18 |
|
21 | 19 | groups: list[str] = search_parameters[group_by_key.value] |
@@ -108,22 +106,33 @@ def collect_and_prepare(group_by_key: Group, search_parameters: dict[str, Any], |
108 | 106 |
|
109 | 107 | __count_classes_and_output_table(output_directory, output_directory / 'class_idx.txt' ) |
110 | 108 |
|
111 | | -def __download_file_from_google_drive(url, destination): |
112 | | - if destination.exists(): |
113 | | - print(f"{destination} already exists. Skipping download.") |
114 | | - return |
115 | | - |
116 | | - print(f"{destination} does not exist. Downloading...") |
| 109 | +def __download_file_from_google_drive(drive_url, destination): |
| 110 | + # Extract the file ID from the Google Drive URL |
| 111 | + file_id = drive_url.split('/d/')[1].split('/')[0] |
| 112 | + URL = "https://drive.google.com/uc?export=download" |
117 | 113 |
|
118 | | - response = requests.get(url, stream=True) |
119 | | - if response.status_code == 200: |
120 | | - with open(destination, 'wb') as f: |
121 | | - for chunk in response.iter_content(8192): |
122 | | - if chunk: |
| 114 | + # Send a request to Google Drive to start the file download |
| 115 | + with requests.Session() as session: |
| 116 | + response = session.get(URL, params={'id': file_id}, stream=True) |
| 117 | + |
| 118 | + # Get confirmation token if required (for large files) |
| 119 | + token = None |
| 120 | + for key, value in response.cookies.items(): |
| 121 | + if key.startswith('download_warning'): |
| 122 | + token = value |
| 123 | + |
| 124 | + if token: |
| 125 | + # Reattempt download with confirmation token |
| 126 | + response = session.get(URL, params={'id': file_id, 'confirm': token}, stream=True) |
| 127 | + |
| 128 | + # Save the file content |
| 129 | + CHUNK_SIZE = 32768 |
| 130 | + with open(destination, "wb") as f: |
| 131 | + for chunk in response.iter_content(CHUNK_SIZE): |
| 132 | + if chunk: # Filter out keep-alive new chunks |
123 | 133 | f.write(chunk) |
| 134 | + |
124 | 135 | print(f"File downloaded successfully and saved at: {destination}") |
125 | | - else: |
126 | | - print(f"Failed to download file. Status code: {response.status_code}") |
127 | 136 |
|
128 | 137 | def _fetch_occurrences(group_key: str, group_value: str, parameters: dict[str, Any], totalLimit: int) -> list[dict[str, Any]]: |
129 | 138 | parameters[group_key] = group_value |
|
0 commit comments