|
6 | 6 | )
|
7 | 7 | from keras_hub.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
|
8 | 8 | from keras_hub.src.utils.preset_utils import builtin_presets
|
9 |
| -from keras_hub.src.utils.preset_utils import find_subclass |
10 | 9 | from keras_hub.src.utils.preset_utils import get_preset_loader
|
11 | 10 | from keras_hub.src.utils.preset_utils import get_preset_saver
|
12 | 11 | from keras_hub.src.utils.python_utils import classproperty
|
@@ -171,43 +170,38 @@ def from_preset(
|
171 | 170 | )
|
172 | 171 | ```
|
173 | 172 | """
|
174 |
| - if cls == Preprocessor: |
| 173 | + if cls is Preprocessor: |
175 | 174 | raise ValueError(
|
176 | 175 | "Do not call `Preprocessor.from_preset()` directly. Instead "
|
177 | 176 | "choose a particular task preprocessing class, e.g. "
|
178 | 177 | "`keras_hub.models.TextClassifierPreprocessor.from_preset()`."
|
179 | 178 | )
|
180 | 179 |
|
181 | 180 | loader = get_preset_loader(preset)
|
182 |
| - backbone_cls = loader.check_backbone_class() |
183 |
| - # Detect the correct subclass if we need to. |
184 |
| - if cls.backbone_cls != backbone_cls: |
185 |
| - cls = find_subclass(preset, cls, backbone_cls) |
186 |
| - return loader.load_preprocessor(cls, config_file, **kwargs) |
| 181 | + return loader.load_preprocessor( |
| 182 | + cls=cls, config_file=config_file, kwargs=kwargs |
| 183 | + ) |
187 | 184 |
|
188 | 185 | @classmethod
|
189 |
| - def _add_missing_kwargs(cls, loader, kwargs): |
190 |
| - """Fill in required kwargs when loading from preset. |
191 |
| -
|
192 |
| - This is a private method hit when loading a preprocessing layer that |
193 |
| - was not directly saved in the preset. This method should fill in |
194 |
| - all required kwargs required to call the class constructor. For almost, |
195 |
| - all preprocessors, the only required args are `tokenizer`, |
196 |
| - `image_converter`, and `audio_converter`, but this can be overridden, |
197 |
| - e.g. for a preprocessor with multiple tokenizers for different |
198 |
| - encoders. |
| 186 | + def _from_defaults(cls, loader, kwargs): |
| 187 | + """Load a preprocessor from default values. |
| 188 | +
|
| 189 | + This is a private method hit for loading a preprocessing layer that was |
| 190 | + not directly saved in the preset. Usually this means loading a |
| 191 | + tokenizer, image_converter and/or audio_converter and calling the |
| 192 | + constructor. But this can be overridden by subclasses as needed. |
199 | 193 | """
|
| 194 | + defaults = {} |
| 195 | + # Allow loading any tokenizer, image_converter or audio_converter config |
| 196 | + # we find on disk. We allow mixing a matching tokenizers and |
| 197 | + # preprocessing layers (though this is usually not a good idea). |
200 | 198 | if "tokenizer" not in kwargs and cls.tokenizer_cls:
|
201 |
| - kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls) |
| 199 | + defaults["tokenizer"] = loader.load_tokenizer() |
202 | 200 | if "audio_converter" not in kwargs and cls.audio_converter_cls:
|
203 |
| - kwargs["audio_converter"] = loader.load_audio_converter( |
204 |
| - cls.audio_converter_cls |
205 |
| - ) |
| 201 | + defaults["audio_converter"] = loader.load_audio_converter() |
206 | 202 | if "image_converter" not in kwargs and cls.image_converter_cls:
|
207 |
| - kwargs["image_converter"] = loader.load_image_converter( |
208 |
| - cls.image_converter_cls |
209 |
| - ) |
210 |
| - return kwargs |
| 203 | + defaults["image_converter"] = loader.load_image_converter() |
| 204 | + return cls(**{**defaults, **kwargs}) |
211 | 205 |
|
212 | 206 | def load_preset_assets(self, preset):
|
213 | 207 | """Load all static assets needed by the preprocessing layer.
|
|
0 commit comments