|
98 | 98 | from autosklearn.smbo import AutoMLSMBO |
99 | 99 | from autosklearn.util import RE_PATTERN, pipeline |
100 | 100 | from autosklearn.util.dask import Dask, LocalDask, UserDask |
101 | | -from autosklearn.util.data import ( |
102 | | - DatasetCompressionSpec, |
103 | | - default_dataset_compression_arg, |
104 | | - reduce_dataset_size_if_too_large, |
105 | | - supported_precision_reductions, |
106 | | - validate_dataset_compression_arg, |
107 | | -) |
| 101 | +from autosklearn.util.data import DatasetCompression |
108 | 102 | from autosklearn.util.logging_ import ( |
109 | 103 | PicklableClientLogger, |
110 | 104 | get_named_client_logger, |
@@ -252,15 +246,14 @@ def __init__( |
252 | 246 | ) |
253 | 247 |
|
254 | 248 | # Validate dataset_compression and set its values |
255 | | - self._dataset_compression: DatasetCompressionSpec | None = None |
256 | | - if isinstance(dataset_compression, bool): |
257 | | - if dataset_compression is True: |
258 | | - self._dataset_compression = default_dataset_compression_arg |
259 | | - else: |
260 | | - self._dataset_compression = validate_dataset_compression_arg( |
261 | | - dataset_compression, |
262 | | - memory_limit=memory_limit, |
263 | | - ) |
| 249 | + self._dataset_compression: DatasetCompression | None = None |
| 250 | + if dataset_compression is not False: |
| 251 | + |
| 252 | + if memory_limit is None: |
| 253 | + raise ValueError("Must provide a `memory_limit` for data compression") |
| 254 | + |
| 255 | + spec = {} if dataset_compression is True else dataset_compression |
| 256 | + self._dataset_compression = DatasetCompression(**spec, limit=memory_limit) |
264 | 257 |
|
265 | 258 | # If we got something callable for `get_trials_callback`, wrap it so SMAC |
266 | 259 | # will accept it. |
@@ -667,30 +660,13 @@ def fit( |
667 | 660 | X_test, y_test = input_validator.transform(X_test, y_test) |
668 | 661 |
|
669 | 662 | # We don't support size reduction on pandas type object yet |
670 | | - if ( |
671 | | - self._dataset_compression is not None |
672 | | - and not isinstance(X, pd.DataFrame) |
673 | | - and not (isinstance(y, pd.Series) or isinstance(y, pd.DataFrame)) |
674 | | - ): |
675 | | - methods = self._dataset_compression["methods"] |
676 | | - memory_allocation = self._dataset_compression["memory_allocation"] |
677 | | - |
678 | | - # Remove precision reduction if we can't perform it |
679 | | - if ( |
680 | | - "precision" in methods |
681 | | - and X.dtype not in supported_precision_reductions |
682 | | - ): |
683 | | - methods = [method for method in methods if method != "precision"] |
684 | | - |
| 663 | + if self._dataset_compression and self._dataset_compression.supports(X, y): |
685 | 664 | with warnings_to(self.logger): |
686 | | - X, y = reduce_dataset_size_if_too_large( |
| 665 | + X, y = self._dataset_compression.compress( |
687 | 666 | X=X, |
688 | 667 | y=y, |
689 | | - memory_limit=self._memory_limit, |
690 | | - is_classification=self.is_classification, |
| 668 | + stratify=self.is_classification, |
691 | 669 | random_state=self._seed, |
692 | | - operations=methods, |
693 | | - memory_allocation=memory_allocation, |
694 | 670 | ) |
695 | 671 |
|
696 | 672 | # Check the re-sampling strategy |
|
0 commit comments