-
Notifications
You must be signed in to change notification settings - Fork 36
Huggingface #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Huggingface #96
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import asdict, dataclass | ||
|
|
@@ -22,9 +23,80 @@ | |
| from ..predictions.tabular_predictions import TabularModelPredictions | ||
| from ..repository.evaluation_repository import EvaluationRepository | ||
| from ..utils import catchtime | ||
| from ..utils.huggingfacehub_utils import download_from_huggingface | ||
| from ..utils.download import download_files | ||
|
|
||
|
|
||
| def download_from_s3(name: str, include_zs: bool, exists: str, dry_run: bool, s3_download_map, benchmark_paths, verbose: bool): | ||
| print(f'Downloading files for {name} context... ' | ||
| f'(include_zs={include_zs}, exists="{exists}", dry_run={dry_run})') | ||
| if dry_run: | ||
| print(f'\tNOTE: `dry_run=True`! Files will not be downloaded.') | ||
| assert exists in ["raise", "ignore", "overwrite"] | ||
| assert s3_download_map is not None, \ | ||
| f'self.s3_download_map is None: download functionality is disabled' | ||
| file_paths_expected = benchmark_paths.get_file_paths(include_zs=include_zs) | ||
|
|
||
| file_paths_to_download = [f for f in file_paths_expected if f in s3_download_map] | ||
| if len(file_paths_to_download) == 0: | ||
| print(f'WARNING: Matching file paths to download is 0! ' | ||
| f'`self.s3_download_map` probably has incorrect keys.') | ||
| file_paths_already_exist = [f for f in file_paths_to_download if benchmark_paths.exists(f)] | ||
| file_paths_missing = [f for f in file_paths_to_download if not benchmark_paths.exists(f)] | ||
|
|
||
| if exists == 'raise': | ||
| if file_paths_already_exist: | ||
| raise AssertionError(f'`exists="{exists}"`, ' | ||
| f'and found {len(file_paths_already_exist)} files that already exist locally!\n' | ||
| f'\tExisting Files: {file_paths_already_exist}\n' | ||
| f'\tMissing Files: {file_paths_missing}\n' | ||
| f'Either manually inspect and delete existing files, ' | ||
| f'set `exists="ignore"` to keep your local files and only download missing files, ' | ||
| f'or set `exists="overwrite"` to overwrite your existing local files.') | ||
| elif exists == 'ignore': | ||
| file_paths_to_download = file_paths_missing | ||
| elif exists == 'overwrite': | ||
| file_paths_to_download = file_paths_to_download | ||
| else: | ||
| raise ValueError(f'Invalid value for exists (`exists="{exists}"`). ' | ||
| f'Valid values: {["raise", "ignore", "overwrite"]}') | ||
|
|
||
| s3_to_local_tuple_list = [(val, key) for key, val in s3_download_map.items() | ||
| if key in file_paths_to_download] | ||
|
|
||
| log_extra = '' | ||
|
|
||
| num_exist = len(file_paths_already_exist) | ||
| if exists == 'overwrite': | ||
| if num_exist > 0: | ||
| log_extra += f'\tWill overwrite {num_exist} files that exist locally:\n' \ | ||
| f'\t\t{file_paths_already_exist}' | ||
| else: | ||
| log_extra = f'' | ||
| if exists == 'ignore': | ||
| log_extra += f'\tWill skip {num_exist} files that exist locally:\n' \ | ||
| f'\t\t{file_paths_already_exist}' | ||
| if file_paths_missing: | ||
| if log_extra: | ||
| log_extra += '\n' | ||
| log_extra += f'Will download {len(file_paths_missing)} files that are missing locally:\n' \ | ||
| f'\t\t{file_paths_missing}' | ||
|
|
||
| if log_extra: | ||
| print(log_extra) | ||
| print(f'\tDownloading {len(s3_to_local_tuple_list)} files from s3 to local...') | ||
| for s3_path, local_path in s3_to_local_tuple_list: | ||
| print(f'\t\t"{s3_path}" -> "{local_path}"') | ||
| s3_required_list = [(s3_path, local_path) for s3_path, local_path in s3_to_local_tuple_list if | ||
| s3_path[:2] == "s3"] | ||
| urllib_required_list = [(s3_path, local_path) for s3_path, local_path in s3_to_local_tuple_list if | ||
| s3_path[:2] != "s3"] | ||
| if urllib_required_list: | ||
| download_files(remote_to_local_tuple_list=urllib_required_list, dry_run=dry_run, verbose=verbose) | ||
| if s3_required_list: | ||
| download_s3_files(s3_to_local_tuple_list=s3_required_list, dry_run=dry_run, verbose=verbose) | ||
|
|
||
|
|
||
| @dataclass | ||
| class BenchmarkPaths: | ||
| configs: str | ||
|
|
@@ -260,7 +332,9 @@ def download(self, | |
| include_zs: bool = True, | ||
| exists: str = 'raise', | ||
| verbose: bool = True, | ||
| dry_run: bool = False): | ||
| dry_run: bool = False, | ||
| use_s3: bool = True, | ||
| ): | ||
| """ | ||
| Downloads all BenchmarkContext required files from s3 to local disk. | ||
|
|
||
|
|
@@ -275,78 +349,27 @@ def download(self, | |
| Guarantees alignment between local and remote files (at the time of download) | ||
| :param dry_run: If True, will not download files, but instead log what would have been downloaded. | ||
| """ | ||
| print(f'Downloading files for {self.name} context... ' | ||
| f'(include_zs={include_zs}, exists="{exists}", dry_run={dry_run})') | ||
| if dry_run: | ||
| print(f'\tNOTE: `dry_run=True`! Files will not be downloaded.') | ||
| assert exists in ["raise", "ignore", "overwrite"] | ||
| assert self.s3_download_map is not None, \ | ||
| f'self.s3_download_map is None: download functionality is disabled' | ||
| file_paths_expected = self.benchmark_paths.get_file_paths(include_zs=include_zs) | ||
|
|
||
| file_paths_to_download = [f for f in file_paths_expected if f in self.s3_download_map] | ||
| if len(file_paths_to_download) == 0: | ||
| print(f'WARNING: Matching file paths to download is 0! ' | ||
| f'`self.s3_download_map` probably has incorrect keys.') | ||
| file_paths_already_exist = [f for f in file_paths_to_download if self.benchmark_paths.exists(f)] | ||
| file_paths_missing = [f for f in file_paths_to_download if not self.benchmark_paths.exists(f)] | ||
|
|
||
| if exists == 'raise': | ||
| if file_paths_already_exist: | ||
| raise AssertionError(f'`exists="{exists}"`, ' | ||
| f'and found {len(file_paths_already_exist)} files that already exist locally!\n' | ||
| f'\tExisting Files: {file_paths_already_exist}\n' | ||
| f'\tMissing Files: {file_paths_missing}\n' | ||
| f'Either manually inspect and delete existing files, ' | ||
| f'set `exists="ignore"` to keep your local files and only download missing files, ' | ||
| f'or set `exists="overwrite"` to overwrite your existing local files.') | ||
| elif exists == 'ignore': | ||
| file_paths_to_download = file_paths_missing | ||
| elif exists == 'overwrite': | ||
| file_paths_to_download = file_paths_to_download | ||
| if use_s3: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we move this into a separate function?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
| download_from_s3( | ||
| name=self.name, include_zs=include_zs, exists=exists, dry_run=dry_run, | ||
| s3_download_map=self.s3_download_map, benchmark_paths=self.benchmark_paths, verbose=verbose | ||
| ) | ||
| else: | ||
| raise ValueError(f'Invalid value for exists (`exists="{exists}"`). ' | ||
| f'Valid values: {["raise", "ignore", "overwrite"]}') | ||
|
|
||
| s3_to_local_tuple_list = [(val, key) for key, val in self.s3_download_map.items() | ||
| if key in file_paths_to_download] | ||
|
|
||
| log_extra = '' | ||
|
|
||
| num_exist = len(file_paths_already_exist) | ||
| if exists == 'overwrite': | ||
| if num_exist > 0: | ||
| log_extra += f'\tWill overwrite {num_exist} files that exist locally:\n' \ | ||
| f'\t\t{file_paths_already_exist}' | ||
| else: | ||
| log_extra = f'' | ||
| if exists == 'ignore': | ||
| log_extra += f'\tWill skip {num_exist} files that exist locally:\n' \ | ||
| f'\t\t{file_paths_already_exist}' | ||
| if file_paths_missing: | ||
| if log_extra: | ||
| log_extra += '\n' | ||
| log_extra += f'Will download {len(file_paths_missing)} files that are missing locally:\n' \ | ||
| f'\t\t{file_paths_missing}' | ||
|
|
||
| if log_extra: | ||
| print(log_extra) | ||
| print(f'\tDownloading {len(s3_to_local_tuple_list)} files from s3 to local...') | ||
| for s3_path, local_path in s3_to_local_tuple_list: | ||
| print(f'\t\t"{s3_path}" -> "{local_path}"') | ||
| s3_required_list = [(s3_path, local_path) for s3_path, local_path in s3_to_local_tuple_list if s3_path[:2] == "s3"] | ||
| urllib_required_list = [(s3_path, local_path) for s3_path, local_path in s3_to_local_tuple_list if s3_path[:2] != "s3"] | ||
| if urllib_required_list: | ||
| download_files(remote_to_local_tuple_list=urllib_required_list, dry_run=dry_run, verbose=verbose) | ||
| if s3_required_list: | ||
| download_s3_files(s3_to_local_tuple_list=s3_required_list, dry_run=dry_run, verbose=verbose) | ||
| if verbose: | ||
| print(f'Downloading files for {self.name} context... ' | ||
| f'(include_zs={include_zs}, exists="{exists}")') | ||
| download_from_huggingface( | ||
| datasets=self.benchmark_paths.datasets, | ||
| ) | ||
|
|
||
| def load(self, | ||
| folds: List[int] = None, | ||
| load_predictions: bool = True, | ||
| download_files: bool = True, | ||
| prediction_format: str = "memmap", | ||
| exists: str = 'ignore') -> Tuple[ZeroshotSimulatorContext, TabularModelPredictions, GroundTruth]: | ||
| exists: str = 'ignore', | ||
| use_s3: bool = True, | ||
| ) -> Tuple[ZeroshotSimulatorContext, TabularModelPredictions, GroundTruth]: | ||
| """ | ||
| :param folds: If None, uses self.folds as default. | ||
| If specified, must be a subset of `self.folds`. This will filter the results to only the specified folds. | ||
|
|
@@ -397,7 +420,7 @@ def load(self, | |
| missing_files_str = [f'\n\t"{m}"' for m in missing_files] | ||
| raise FileNotFoundError(f'Missing {len(missing_files)} required files: \n[{",".join(missing_files_str)}\n]') | ||
| print(f'Downloading input files from s3...') | ||
| self.download(include_zs=load_predictions, exists=exists) | ||
| self.download(include_zs=load_predictions, exists=exists, use_s3=use_s3) | ||
| self.benchmark_paths.assert_exists_all(check_zs=load_predictions) | ||
|
|
||
| configs_hyperparameters = self.load_configs_hyperparameters() | ||
|
|
@@ -419,13 +442,15 @@ def load_repo( | |
| download_files: bool = True, | ||
| prediction_format: str = "memmap", | ||
| exists: str = 'ignore', | ||
| use_s3: bool = True, | ||
| ) -> EvaluationRepository: | ||
| zsc, zeroshot_pred_proba, zeroshot_gt = self.load( | ||
| folds=folds, | ||
| load_predictions=load_predictions, | ||
| download_files=download_files, | ||
| prediction_format=prediction_format, | ||
| exists=exists, | ||
| use_s3=use_s3, | ||
| ) | ||
| repo = EvaluationRepository( | ||
| zeroshot_context=zsc, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General comment:
I will probably lean towards keeping the s3 download logic in addition to the HF download logic, as it can be useful when dealing with private data that hasn't been approved for public release yet, as it is easier to put this on a private s3 bucket.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, but then wouldn't it make sense to move the S3 logic into a utils module, given that it is private?
Let me know if this solution works for you, I can also add a flag to configure whether to use s3 or HF.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main requirement is that there should be a flag to specify easily between the two. For example, if we have internal teams who want to share their internal runs with us / store them for their own use, they probably won't have permission to create a HF repo, so they would need to use S3.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add a flag between the two, but would it be sufficient to merge the PR?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a flag with the default still being s3 for now, will probably switch to HF default later on.
Once that is added + the s3 logic is added back in, I'm happy to merge
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok thanks, I added back an option to use S3.
I would recommend switching to HF though, as downloading TabRepo on external machines is currently painful (it takes more than a day). I benchmarked the two on the 30 datasets and HF was ~3x faster (I believe it will be even faster with larger datasets, as it has been optimized for this use case).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely makes sense. I just want to first do a few things (like refactoring configs_hyperparameters) before switching the default.