|
20 | 20 | from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
21 | 21 | from invokeai.app.invocations.model import UNetField, VAEField
|
22 | 22 | from invokeai.app.services.shared.invocation_context import InvocationContext
|
23 |
| -from invokeai.backend.model_manager import LoadedModel |
24 |
| -from invokeai.backend.model_manager.config import Main_Config_Base |
25 |
| -from invokeai.backend.model_manager.taxonomy import ModelVariantType |
| 23 | +from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType |
26 | 24 | from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
27 | 25 |
|
28 | 26 |
|
@@ -182,10 +180,11 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
182 | 180 | if self.unet is not None and self.vae is not None and self.image is not None:
|
183 | 181 | # all three fields must be present at the same time
|
184 | 182 | main_model_config = context.models.get_config(self.unet.unet.key)
|
185 |
| - assert isinstance(main_model_config, Main_Config_Base) |
186 |
| - if main_model_config.variant is ModelVariantType.Inpaint: |
| 183 | + assert main_model_config.type is ModelType.Main |
| 184 | + variant = getattr(main_model_config, "variant", None) |
| 185 | + if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill: |
187 | 186 | mask = dilated_mask_tensor
|
188 |
| - vae_info: LoadedModel = context.models.load(self.vae.vae) |
| 187 | + vae_info = context.models.load(self.vae.vae) |
189 | 188 | image = context.images.get_pil(self.image.image_name)
|
190 | 189 | image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
191 | 190 | if image_tensor.dim() == 3:
|
|
0 commit comments