diff --git a/docs/contributing/RECALL_PARAMETERS/RECALL_API_LORAS_CONTROLNETS_IMAGES.md b/docs/contributing/RECALL_PARAMETERS/RECALL_API_LORAS_CONTROLNETS_IMAGES.md new file mode 100644 index 00000000000..75d68e44c0c --- /dev/null +++ b/docs/contributing/RECALL_PARAMETERS/RECALL_API_LORAS_CONTROLNETS_IMAGES.md @@ -0,0 +1,375 @@ +# Recall Parameters API - LoRAs, ControlNets, and IP Adapters with Images + +## Overview + +The Recall Parameters API supports recalling LoRAs, ControlNets (including T2I Adapters and Control LoRAs), and IP Adapters along with their associated weights and settings. Control Layers and IP Adapters can now include image references from the `INVOKEAI_ROOT/outputs/images` directory for fully functional control and image prompt functionality. + +## Key Features + +✅ **LoRAs**: Fully functional - adds to UI, queries model configs, applies weights +✅ **Control Layers**: Full support with optional images from outputs/images +✅ **IP Adapters**: Full support with optional reference images from outputs/images +✅ **Model Name Resolution**: Automatic lookup from human-readable names to internal keys +✅ **Image Validation**: Backend validates that image files exist before sending + +## Endpoints + +### POST `/api/v1/recall/{queue_id}` + +Updates recallable parameters for the frontend, including LoRAs, control adapters, and IP adapters with optional images. + +**Path Parameters:** +- `queue_id` (string): The queue ID to associate parameters with (typically "default") + +**Request Body:** + +All fields are optional. Include only the parameters you want to update. + +```typescript +{ + // Standard parameters + positive_prompt?: string; + negative_prompt?: string; + model?: string; // Model name or key + steps?: number; + cfg_scale?: number; + width?: number; + height?: number; + seed?: number; + // ... 
other standard parameters
+
+  // LoRAs
+  loras?: Array<{
+    model_name: string;            // LoRA model name
+    weight?: number;               // Default: 0.75, Range: -10 to 10
+    is_enabled?: boolean;          // Default: true
+  }>;
+
+  // Control Layers (ControlNet, T2I Adapter, Control LoRA)
+  control_layers?: Array<{
+    model_name: string;            // Control adapter model name
+    image_name?: string;           // Optional image filename from outputs/images
+    weight?: number;               // Default: 1.0, Range: -1 to 2
+    begin_step_percent?: number;   // Default: 0.0, Range: 0 to 1
+    end_step_percent?: number;     // Default: 1.0, Range: 0 to 1
+    control_mode?: "balanced" | "more_prompt" | "more_control"; // ControlNet only
+  }>;
+
+  // IP Adapters
+  ip_adapters?: Array<{
+    model_name: string;            // IP Adapter model name
+    image_name?: string;           // Optional reference image filename from outputs/images
+    weight?: number;               // Default: 1.0, Range: -1 to 2
+    begin_step_percent?: number;   // Default: 0.0, Range: 0 to 1
+    end_step_percent?: number;     // Default: 1.0, Range: 0 to 1
+    method?: "full" | "style" | "composition"; // Default: "full"
+    image_influence?: "lowest" | "low" | "medium" | "high" | "highest"; // FLUX Redux models only; omitted unless provided
+  }>;
+}
+```
+
+## Model Name Resolution
+
+The backend automatically resolves model names to their internal keys:
+
+1. **Main Models**: Resolved from the name to the model key
+2. **LoRAs**: Searched in the LoRA model database
+3. **Control Adapters**: Tried in order - ControlNet → T2I Adapter → Control LoRA
+4. **IP Adapters**: Searched in the IP Adapter model database, then the FLUX Redux model database
+
+Models that cannot be resolved are skipped with a warning in the logs.
+
+## Image File Handling
+
+### Image Path Resolution
+
+When you specify an `image_name`, the backend:
+1. Constructs the full path: `{INVOKEAI_ROOT}/outputs/images/{image_name}`
+2. Validates that the file exists
+3. Includes the image reference in the event sent to the frontend
+4. 
Logs whether the image was found or not + +### Image Naming + +Images should be referenced by their filename as it appears in the outputs/images directory: +- ✅ Correct: `"image_name": "example.png"` +- ✅ Correct: `"image_name": "my_control_image_20240110.jpg"` +- ❌ Incorrect: `"image_name": "outputs/images/example.png"` (use relative filename only) +- ❌ Incorrect: `"image_name": "/full/path/to/example.png"` (use relative filename only) + +## Frontend Behavior + +### LoRAs +- **Fully Supported**: LoRAs are immediately added to the LoRA list in the UI +- Existing LoRAs are cleared before adding new ones +- Each LoRA's model config is fetched and applied with the specified weight +- LoRAs appear in the LoRA selector panel + +### Control Layers with Images +- **Fully Supported**: Control layers now support images from outputs/images +- Configuration includes model, weights, step percentages, and image reference +- Image availability is logged in frontend console +- Images can be used to create actual control layers through the UI + +### IP Adapters with Images +- **Fully Supported**: IP Adapters now support reference images from outputs/images +- Configuration includes model, weights, step percentages, method, and image reference +- Image availability is logged in frontend console +- Images can be used to create actual reference image layers through the UI + +## Examples + +### 1. Add LoRAs Only + +```bash +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{ + "loras": [ + { + "model_name": "add-detail-xl", + "weight": 0.8, + "is_enabled": true + }, + { + "model_name": "sd_xl_offset_example-lora_1.0", + "weight": 0.5, + "is_enabled": true + } + ] + }' +``` + +### 2. Configure Control Layers with Image + +Replace `my_control_image.png` with an actual image filename from your outputs/images directory. 
+ +```bash +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{ + "control_layers": [ + { + "model_name": "controlnet-canny-sdxl-1.0", + "image_name": "my_control_image.png", + "weight": 0.75, + "begin_step_percent": 0.0, + "end_step_percent": 0.8, + "control_mode": "balanced" + } + ] + }' +``` + +### 3. Configure IP Adapters with Reference Image + +Replace `reference_face.png` with an actual image filename from your outputs/images directory. + +```bash +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{ + "ip_adapters": [ + { + "model_name": "ip-adapter-plus-face_sd15", + "image_name": "reference_face.png", + "weight": 0.7, + "begin_step_percent": 0.0, + "end_step_percent": 1.0, + "method": "composition" + } + ] + }' +``` + +### 4. Complete Configuration with All Features + +```bash +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{ + "positive_prompt": "masterpiece, detailed photo with specific style", + "negative_prompt": "blurry, low quality", + "model": "FLUX Schnell", + "steps": 25, + "cfg_scale": 8.0, + "width": 1024, + "height": 768, + "seed": 42, + "loras": [ + { + "model_name": "add-detail-xl", + "weight": 0.6, + "is_enabled": true + } + ], + "control_layers": [ + { + "model_name": "controlnet-depth-sdxl-1.0", + "image_name": "depth_map.png", + "weight": 1.0, + "begin_step_percent": 0.0, + "end_step_percent": 0.7 + } + ], + "ip_adapters": [ + { + "model_name": "ip-adapter-plus-face_sd15", + "image_name": "style_reference.png", + "weight": 0.5, + "begin_step_percent": 0.0, + "end_step_percent": 1.0, + "method": "style" + } + ] + }' +``` + +## Response Format + +```json +{ + "status": "success", + "queue_id": "default", + "updated_count": 15, + "parameters": { + "positive_prompt": "...", + "steps": 25, + "loras": [ + { + "model_key": "abc123...", + "weight": 0.6, + "is_enabled": true + } + ], 
+ "control_layers": [ + { + "model_key": "controlnet-xyz...", + "weight": 1.0, + "image": { + "image_name": "depth_map.png" + } + } + ], + "ip_adapters": [ + { + "model_key": "ip-adapter-xyz...", + "weight": 0.5, + "image": { + "image_name": "style_reference.png" + } + } + ] + } +} +``` + +## WebSocket Events + +When parameters are updated, a `recall_parameters_updated` event is emitted via WebSocket to the queue room. The frontend automatically: + +1. Applies standard parameters (prompts, steps, dimensions, etc.) +2. Loads and adds LoRAs to the LoRA list +3. Logs control layer and IP adapter configurations with image information +4. Makes image references available for manual canvas/reference image creation + +## Logging + +### Backend Logs + +Backend logs show: +- Model name → key resolution (success/failure) +- Image file validation (found/not found) +- Parameter storage confirmation +- Event emission status + +Example log messages: +``` +INFO: Resolved ControlNet model name 'controlnet-canny-sdxl-1.0' to key 'controlnet-xyz...' +INFO: Found image file: depth_map.png +INFO: Updated 12 recall parameters for queue default +INFO: Resolved 1 LoRA(s) +INFO: Resolved 1 control layer(s) +INFO: Resolved 1 IP adapter(s) +``` + +### Frontend Logs + +Frontend logs (check browser console): +- Set `localStorage.ROARR_FILTER = 'debug'` to see all debug messages +- Look for messages from the `events` namespace +- LoRA loading, model resolution, and parameter application are logged + +Example log messages: +``` +INFO: Applied 5 recall parameters to store +INFO: Received 1 control layer(s) with image support +INFO: Control layer 1: controlnet-xyz... (weight: 0.75, image: depth_map.png) +DEBUG: Control layer 1 image available at: outputs/images/depth_map.png +INFO: Received 1 IP adapter(s) with image support +INFO: IP adapter 1: ip-adapter-xyz... 
(weight: 0.7, image: style_reference.png) +DEBUG: IP adapter 1 image available at: outputs/images/style_reference.png +``` + +## Limitations + +1. **Canvas Integration**: Control layers and IP adapters with images are currently logged but not automatically added to canvas layers + - Users can view the configuration and manually create canvas layers with the provided images + - Future enhancement: Auto-create canvas layers with stored images + +2. **Model Availability**: Models must be installed in InvokeAI before they can be recalled + +3. **Image Availability**: Images must exist in the outputs/images directory + - Missing images are logged as warnings but don't fail the request + - Other parameters are still applied even if images are missing + +4. **Image URLs**: Only local filenames from outputs/images are supported + - Remote image URLs are not currently supported + +## Testing + +Use the provided test script: + +```bash +./test_recall_loras_controlnets.sh +``` + +This will test: +- LoRA addition with multiple models +- Control layer configuration with image references +- IP adapter configuration with image references +- Combined parameter updates with all features + +Note: Update the image names in the test script to match actual images in your outputs/images directory. + +## Troubleshooting + +### Images Not Found + +If you see "Image file not found" in the logs: +1. Verify the image filename matches exactly (case-sensitive) +2. Ensure the image is in `{INVOKEAI_ROOT}/outputs/images/` +3. Check that the filename doesn't include the `outputs/images/` prefix + +### Models Not Found + +If you see "Could not find model" messages: +1. Verify the model name matches exactly (case-sensitive) +2. Ensure the model is installed in InvokeAI +3. Check the model name using the models browser in the UI + +### Event Not Received + +If the frontend doesn't receive the event: +1. Check browser console for connection errors +2. 
Verify the queue_id matches the frontend's queue (usually "default") +3. Check backend logs for event emission errors + +## Future Enhancements + +Potential improvements: +1. Auto-create canvas layers with provided control layer images +2. Auto-create reference image layers with provided IP adapter images +3. Support for image URLs +4. Batch operations for multiple queue IDs +5. Image upload capability (accept base64 or file upload) diff --git a/docs/contributing/RECALL_PARAMETERS/RECALL_PARAMETERS_API.md b/docs/contributing/RECALL_PARAMETERS/RECALL_PARAMETERS_API.md new file mode 100644 index 00000000000..0c44cf38e34 --- /dev/null +++ b/docs/contributing/RECALL_PARAMETERS/RECALL_PARAMETERS_API.md @@ -0,0 +1,208 @@ +# Recall Parameters API + +## Overview + +A new REST API endpoint has been added to the InvokeAI backend that allows programmatic updates to recallable parameters from another process. This enables external applications or scripts to modify frontend parameters like prompts, models, and step counts via HTTP requests. + +When parameters are updated via the API, the backend automatically broadcasts a WebSocket event to all connected frontend clients subscribed to that queue, causing them to update immediately. + +## How It Works + +1. **API Request**: External application sends a POST request with parameters to update +2. **Storage**: Parameters are stored in client state persistence, associated with a queue ID +3. **Broadcast**: A WebSocket event (`recall_parameters_updated`) is emitted to all frontend clients listening to that queue +4. **Frontend Update**: Connected frontend clients receive the event and can process the updated parameters +5. **Immediate Display**: The frontend UI updates automatically with the new values + +This means if you have the InvokeAI frontend open in a browser, updating parameters via the API will instantly reflect on the screen without any manual action needed. 
+ +## Endpoint + +**Base URL**: `http://localhost:9090/api/v1/recall/{queue_id}` + +## POST - Update Recall Parameters + +Updates recallable parameters for a given queue ID. + +### Request + +```http +POST /api/v1/recall/{queue_id} +Content-Type: application/json + +{ + "positive_prompt": "a beautiful landscape", + "negative_prompt": "blurry, low quality", + "model": "sd-1.5", + "steps": 20, + "cfg_scale": 7.5, + "width": 512, + "height": 512, + "seed": 12345 +} +``` + +The queue id is usually "default". + +### Parameters + +All parameters are optional. Only provide the parameters you want to update: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `positive_prompt` | string | Positive prompt text | +| `negative_prompt` | string | Negative prompt text | +| `model` | string | Main model name/identifier | +| `refiner_model` | string | Refiner model name/identifier | +| `vae_model` | string | VAE model name/identifier | +| `scheduler` | string | Scheduler name | +| `steps` | integer | Number of generation steps (≥1) | +| `refiner_steps` | integer | Number of refiner steps (≥0) | +| `cfg_scale` | number | CFG scale for guidance | +| `cfg_rescale_multiplier` | number | CFG rescale multiplier | +| `refiner_cfg_scale` | number | Refiner CFG scale | +| `guidance` | number | Guidance scale | +| `width` | integer | Image width in pixels (≥64) | +| `height` | integer | Image height in pixels (≥64) | +| `seed` | integer | Random seed (≥0) | +| `denoise_strength` | number | Denoising strength (0-1) | +| `refiner_denoise_start` | number | Refiner denoising start (0-1) | +| `clip_skip` | integer | CLIP skip layers (≥0) | +| `seamless_x` | boolean | Enable seamless X tiling | +| `seamless_y` | boolean | Enable seamless Y tiling | +| `refiner_positive_aesthetic_score` | number | Refiner positive aesthetic score | +| `refiner_negative_aesthetic_score` | number | Refiner negative aesthetic score | + +### Response + +```json +{ + "status": "success", + 
"queue_id": "queue_123", + "updated_count": 7, + "parameters": { + "positive_prompt": "a beautiful landscape", + "negative_prompt": "blurry, low quality", + "model": "sd-1.5", + "steps": 20, + "cfg_scale": 7.5, + "width": 512, + "height": 512, + "seed": 12345 + } +} +``` + +## GET - Retrieve Recall Parameters + +Retrieves metadata about stored recall parameters. + +### Request + +```http +GET /api/v1/recall/{queue_id} +``` + +### Response + +```json +{ + "status": "success", + "queue_id": "queue_123", + "note": "Use the frontend to access stored recall parameters, or set specific parameters using POST" +} +``` + +## Usage Examples + +### Using cURL + +```bash +# Update prompts and model +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{ + "positive_prompt": "a cyberpunk city at night", + "negative_prompt": "dark, unclear", + "model": "sd-1.5", + "steps": 30 + }' + +# Update just the seed +curl -X POST http://localhost:9090/api/v1/recall/default \ + -H "Content-Type: application/json" \ + -d '{"seed": 99999}' +``` + +### Using Python + +```python +import requests +import json + +# Configuration +API_URL = "http://localhost:9090/api/v1/recall/default" + +# Update multiple parameters +params = { + "positive_prompt": "a serene forest", + "negative_prompt": "people, buildings", + "steps": 25, + "cfg_scale": 7.0, + "seed": 42 +} + +response = requests.post(API_URL, json=params) +result = response.json() + +print(f"Status: {result['status']}") +print(f"Updated {result['updated_count']} parameters") +print(json.dumps(result['parameters'], indent=2)) +``` + +### Using Node.js/JavaScript + +```javascript +const API_URL = 'http://localhost:9090/api/v1/recall/default'; + +const params = { + positive_prompt: 'a beautiful sunset', + negative_prompt: 'blurry', + steps: 20, + width: 768, + height: 768, + seed: 12345 +}; + +fetch(API_URL, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: 
JSON.stringify(params) +}) + .then(res => res.json()) + .then(data => console.log(data)); +``` + +## Implementation Details + +- Parameters are stored in the client state persistence service, using keys prefixed with `recall_` +- The parameters are associated with a `queue_id`, allowing multiple concurrent sessions to maintain separate parameter sets +- Only non-null parameters are processed and stored +- The endpoint provides validation for numeric ranges (e.g., steps ≥ 1, dimensions ≥ 64) +- All parameter values are JSON-serialized for storage +- When parameter values are changed, the backend generates a web sockets event that the frontend listens to. + +## Integration with Frontend + +The stored parameters can be accessed by the frontend through the +existing client state API or by implementing hooks that read from the +recall parameter storage. This allows external applications to +pre-populate generation parameters before the user initiates image +generation. + +## Error Handling + +- **400 Bad Request**: Invalid parameters or parameter values +- **500 Internal Server Error**: Server-side error storing or retrieving parameters + +Errors include detailed messages explaining what went wrong. 
diff --git a/invokeai/app/api/routers/recall_parameters.py b/invokeai/app/api/routers/recall_parameters.py new file mode 100644 index 00000000000..0af3fd29b0c --- /dev/null +++ b/invokeai/app/api/routers/recall_parameters.py @@ -0,0 +1,458 @@ +"""Router for updating recallable parameters on the frontend.""" + +import json +from typing import Any, Literal, Optional + +from fastapi import Body, HTTPException, Path +from fastapi.routing import APIRouter +from pydantic import BaseModel, ConfigDict, Field + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.backend.image_util.controlnet_processor import process_controlnet_image +from invokeai.backend.model_manager.taxonomy import ModelType + +recall_parameters_router = APIRouter(prefix="/v1/recall", tags=["recall"]) + + +class LoRARecallParameter(BaseModel): + """LoRA configuration for recall""" + + model_name: str = Field(description="The name of the LoRA model") + weight: float = Field(default=0.75, ge=-10, le=10, description="The weight for the LoRA") + is_enabled: bool = Field(default=True, description="Whether the LoRA is enabled") + + +class ControlNetRecallParameter(BaseModel): + """ControlNet configuration for recall""" + + model_name: str = Field(description="The name of the ControlNet/T2I Adapter/Control LoRA model") + image_name: Optional[str] = Field(default=None, description="The filename of the control image in outputs/images") + weight: float = Field(default=1.0, ge=-1, le=2, description="The weight for the control adapter") + begin_step_percent: Optional[float] = Field( + default=None, ge=0, le=1, description="When the control adapter is first applied (% of total steps)" + ) + end_step_percent: Optional[float] = Field( + default=None, ge=0, le=1, description="When the control adapter is last applied (% of total steps)" + ) + control_mode: Optional[Literal["balanced", "more_prompt", "more_control"]] = Field( + default=None, description="The control mode (ControlNet only)" + ) + + 
+class IPAdapterRecallParameter(BaseModel): + """IP Adapter configuration for recall""" + + model_name: str = Field(description="The name of the IP Adapter model") + image_name: Optional[str] = Field(default=None, description="The filename of the reference image in outputs/images") + weight: float = Field(default=1.0, ge=-1, le=2, description="The weight for the IP Adapter") + begin_step_percent: Optional[float] = Field( + default=None, ge=0, le=1, description="When the IP Adapter is first applied (% of total steps)" + ) + end_step_percent: Optional[float] = Field( + default=None, ge=0, le=1, description="When the IP Adapter is last applied (% of total steps)" + ) + method: Optional[Literal["full", "style", "composition"]] = Field(default=None, description="The IP Adapter method") + image_influence: Optional[Literal["lowest", "low", "medium", "high", "highest"]] = Field( + default=None, description="FLUX Redux image influence (if model is flux_redux)" + ) + + +class RecallParameter(BaseModel): + """Request model for updating recallable parameters.""" + + model_config = ConfigDict(extra="forbid") + + # Prompts + positive_prompt: Optional[str] = Field(None, description="Positive prompt text") + negative_prompt: Optional[str] = Field(None, description="Negative prompt text") + + # Model configuration + model: Optional[str] = Field(None, description="Main model name/identifier") + refiner_model: Optional[str] = Field(None, description="Refiner model name/identifier") + vae_model: Optional[str] = Field(None, description="VAE model name/identifier") + scheduler: Optional[str] = Field(None, description="Scheduler name") + + # Generation parameters + steps: Optional[int] = Field(None, ge=1, description="Number of generation steps") + refiner_steps: Optional[int] = Field(None, ge=0, description="Number of refiner steps") + cfg_scale: Optional[float] = Field(None, description="CFG scale for guidance") + cfg_rescale_multiplier: Optional[float] = Field(None, description="CFG 
rescale multiplier") + refiner_cfg_scale: Optional[float] = Field(None, description="Refiner CFG scale") + guidance: Optional[float] = Field(None, description="Guidance scale") + + # Image parameters + width: Optional[int] = Field(None, ge=64, description="Image width in pixels") + height: Optional[int] = Field(None, ge=64, description="Image height in pixels") + seed: Optional[int] = Field(None, ge=0, description="Random seed") + + # Advanced parameters + denoise_strength: Optional[float] = Field(None, ge=0, le=1, description="Denoising strength") + refiner_denoise_start: Optional[float] = Field(None, ge=0, le=1, description="Refiner denoising start") + clip_skip: Optional[int] = Field(None, ge=0, description="CLIP skip layers") + seamless_x: Optional[bool] = Field(None, description="Enable seamless X tiling") + seamless_y: Optional[bool] = Field(None, description="Enable seamless Y tiling") + + # Refiner aesthetics + refiner_positive_aesthetic_score: Optional[float] = Field(None, description="Refiner positive aesthetic score") + refiner_negative_aesthetic_score: Optional[float] = Field(None, description="Refiner negative aesthetic score") + + # LoRAs, ControlNets, and IP Adapters + loras: Optional[list[LoRARecallParameter]] = Field(None, description="List of LoRAs with their weights") + control_layers: Optional[list[ControlNetRecallParameter]] = Field( + None, description="List of control adapters (ControlNet, T2I Adapter, Control LoRA) with their settings" + ) + ip_adapters: Optional[list[IPAdapterRecallParameter]] = Field( + None, description="List of IP Adapters with their settings" + ) + + +def resolve_model_name_to_key(model_name: str, model_type: ModelType = ModelType.Main) -> Optional[str]: + """ + Look up a model by name and return its key. + + Args: + model_name: The name of the model to look up + model_type: The type of model to search for (default: Main) + + Returns: + The key of the first matching model, or None if not found. 
+ """ + logger = ApiDependencies.invoker.services.logger + try: + models = ApiDependencies.invoker.services.model_manager.store.search_by_attr( + model_name=model_name, model_type=model_type + ) + + if models: + logger.info(f"Resolved {model_type.value} model name '{model_name}' to key '{models[0].key}'") + return models[0].key + + logger.warning(f"Could not find {model_type.value} model with name '{model_name}'") + return None + except Exception as e: + logger.error(f"Exception during {model_type.value} model lookup: {e}", exc_info=True) + return None + + +def load_image_file(image_name: str) -> Optional[dict[str, Any]]: + """ + Load an image from the outputs/images directory. + + Args: + image_name: The filename of the image in outputs/images + + Returns: + A dictionary with image_name, width, and height, or None if the image cannot be found + """ + logger = ApiDependencies.invoker.services.logger + try: + # Prefer using the image_files service to validate & open images + image_files = ApiDependencies.invoker.services.image_files + # Resolve a safe path inside outputs + image_path = image_files.get_path(image_name) + + if not image_files.validate_path(str(image_path)): + logger.warning(f"Image file not found: {image_name} (searched in {image_path.parent})") + return None + + # Open the image via service to leverage caching + pil_image = image_files.get(image_name) + width, height = pil_image.size + logger.info(f"Found image file: {image_name} ({width}x{height})") + return {"image_name": image_name, "width": width, "height": height} + except Exception as e: + logger.warning(f"Error loading image file {image_name}: {e}") + return None + + +def resolve_lora_models(loras: list[LoRARecallParameter]) -> list[dict[str, Any]]: + """ + Resolve LoRA model names to keys and build configuration list. 
+ + Args: + loras: List of LoRA recall parameters + + Returns: + List of resolved LoRA configurations with model keys + """ + logger = ApiDependencies.invoker.services.logger + resolved_loras = [] + + for lora in loras: + model_key = resolve_model_name_to_key(lora.model_name, ModelType.LoRA) + if model_key: + resolved_loras.append({"model_key": model_key, "weight": lora.weight, "is_enabled": lora.is_enabled}) + else: + logger.warning(f"Skipping LoRA '{lora.model_name}' - model not found") + + return resolved_loras + + +def resolve_control_models(control_layers: list[ControlNetRecallParameter]) -> list[dict[str, Any]]: + """ + Resolve control adapter model names to keys and build configuration list. + + Tries to resolve as ControlNet, T2I Adapter, or Control LoRA in that order. + + Args: + control_layers: List of control adapter recall parameters + + Returns: + List of resolved control adapter configurations with model keys + """ + logger = ApiDependencies.invoker.services.logger + services = ApiDependencies.invoker.services + resolved_controls = [] + + for control in control_layers: + model_key = None + + # Try ControlNet first + model_key = resolve_model_name_to_key(control.model_name, ModelType.ControlNet) + if not model_key: + # Try T2I Adapter + model_key = resolve_model_name_to_key(control.model_name, ModelType.T2IAdapter) + if not model_key: + # Try Control LoRA (also uses LoRA type) + model_key = resolve_model_name_to_key(control.model_name, ModelType.LoRA) + + if model_key: + config: dict[str, Any] = {"model_key": model_key, "weight": control.weight} + if control.image_name is not None: + image_data = load_image_file(control.image_name) + if image_data: + config["image"] = image_data + + # Try to process the image using the model's default processor + processed_image_data = process_controlnet_image(control.image_name, model_key, services) + if processed_image_data: + config["processed_image"] = processed_image_data + logger.info(f"Added processed image for 
control adapter {control.model_name}") + else: + logger.warning(f"Could not load image for control adapter: {control.image_name}") + if control.begin_step_percent is not None: + config["begin_step_percent"] = control.begin_step_percent + if control.end_step_percent is not None: + config["end_step_percent"] = control.end_step_percent + if control.control_mode is not None: + config["control_mode"] = control.control_mode + + resolved_controls.append(config) + else: + logger.warning(f"Skipping control adapter '{control.model_name}' - model not found") + + return resolved_controls + + +def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> list[dict[str, Any]]: + """ + Resolve IP Adapter model names to keys and build configuration list. + + Args: + ip_adapters: List of IP Adapter recall parameters + + Returns: + List of resolved IP Adapter configurations with model keys + """ + logger = ApiDependencies.invoker.services.logger + resolved_adapters = [] + + for adapter in ip_adapters: + # Try resolving as IP Adapter; if not found, try FLUX Redux + model_key = resolve_model_name_to_key(adapter.model_name, ModelType.IPAdapter) + if not model_key: + model_key = resolve_model_name_to_key(adapter.model_name, ModelType.FluxRedux) + if model_key: + config: dict[str, Any] = { + "model_key": model_key, + # Always include weight; ignored by FLUX Redux on the frontend + "weight": adapter.weight, + } + if adapter.image_name is not None: + image_data = load_image_file(adapter.image_name) + if image_data: + config["image"] = image_data + else: + logger.warning(f"Could not load image for IP Adapter: {adapter.image_name}") + if adapter.begin_step_percent is not None: + config["begin_step_percent"] = adapter.begin_step_percent + if adapter.end_step_percent is not None: + config["end_step_percent"] = adapter.end_step_percent + if adapter.method is not None: + config["method"] = adapter.method + # Include FLUX Redux image influence when provided + if 
adapter.image_influence is not None: + config["image_influence"] = adapter.image_influence + + resolved_adapters.append(config) + else: + logger.warning(f"Skipping IP Adapter '{adapter.model_name}' - model not found") + + return resolved_adapters + + +@recall_parameters_router.post( + "/{queue_id}", + operation_id="update_recall_parameters", + response_model=dict[str, Any], +) +async def update_recall_parameters( + queue_id: str = Path(..., description="The queue id to perform this operation on"), + parameters: RecallParameter = Body(..., description="Recall parameters to update"), +) -> dict[str, Any]: + """ + Update recallable parameters that can be recalled on the frontend. + + This endpoint allows updating parameters such as prompt, model, steps, and other + generation settings. These parameters are stored in client state and can be + accessed by the frontend to populate UI elements. + + Args: + queue_id: The queue ID to associate these parameters with + parameters: The RecallParameter object containing the parameters to update + + Returns: + A dictionary containing the updated parameters and status + + Example: + POST /api/v1/recall/{queue_id} + { + "positive_prompt": "a beautiful landscape", + "model": "sd-1.5", + "steps": 20, + "cfg_scale": 7.5, + "width": 512, + "height": 512, + "seed": 12345 + } + """ + logger = ApiDependencies.invoker.services.logger + + try: + # Get only the parameters that were actually provided (non-None values) + provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None} + + if not provided_params: + return {"status": "no_parameters_provided", "updated_count": 0} + + # Store each parameter in client state using a consistent key format + updated_count = 0 + for param_key, param_value in provided_params.items(): + # Convert parameter values to JSON strings for storage + value_str = json.dumps(param_value) + try: + ApiDependencies.invoker.services.client_state_persistence.set_by_key( + queue_id, 
f"recall_{param_key}", value_str + ) + updated_count += 1 + except Exception as e: + logger.error(f"Error setting recall parameter {param_key}: {e}") + raise HTTPException( + status_code=500, + detail=f"Error setting recall parameter {param_key}", + ) + + logger.info(f"Updated {updated_count} recall parameters for queue {queue_id}") + + # Resolve model name to key if a model was provided + if "model" in provided_params and isinstance(provided_params["model"], str): + model_name = provided_params["model"] + model_key = resolve_model_name_to_key(model_name, ModelType.Main) + + if model_key: + logger.info(f"Resolved model name '{model_name}' to key '{model_key}'") + provided_params["model"] = model_key + else: + logger.warning(f"Could not resolve model name '{model_name}' to a model key") + # Remove model from parameters if we couldn't resolve it + del provided_params["model"] + + # Process LoRAs if provided + if "loras" in provided_params: + loras_param = parameters.loras + if loras_param is not None: + resolved_loras = resolve_lora_models(loras_param) + provided_params["loras"] = resolved_loras + logger.info(f"Resolved {len(resolved_loras)} LoRA(s)") + + # Process control layers if provided + if "control_layers" in provided_params: + control_layers_param = parameters.control_layers + if control_layers_param is not None: + resolved_controls = resolve_control_models(control_layers_param) + provided_params["control_layers"] = resolved_controls + logger.info(f"Resolved {len(resolved_controls)} control layer(s)") + + # Process IP adapters if provided + if "ip_adapters" in provided_params: + ip_adapters_param = parameters.ip_adapters + if ip_adapters_param is not None: + resolved_adapters = resolve_ip_adapter_models(ip_adapters_param) + provided_params["ip_adapters"] = resolved_adapters + logger.info(f"Resolved {len(resolved_adapters)} IP adapter(s)") + + # Emit event to notify frontend of parameter updates + try: + logger.info( + f"Emitting recall_parameters_updated 
event for queue {queue_id} with {len(provided_params)} parameters" + ) + ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params) + logger.info("Successfully emitted recall_parameters_updated event") + except Exception as e: + logger.error(f"Error emitting recall parameters event: {e}", exc_info=True) + # Don't fail the request if event emission fails, just log it + + return { + "status": "success", + "queue_id": queue_id, + "updated_count": updated_count, + "parameters": provided_params, + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating recall parameters: {e}") + raise HTTPException( + status_code=500, + detail="Error updating recall parameters", + ) + + +@recall_parameters_router.get( + "/{queue_id}", + operation_id="get_recall_parameters", + response_model=dict[str, Any], +) +async def get_recall_parameters( + queue_id: str = Path(..., description="The queue id to retrieve parameters for"), +) -> dict[str, Any]: + """ + Retrieve all stored recall parameters for a given queue. + + Returns a dictionary of all recall parameters that have been set for the queue. 
+ + Args: + queue_id: The queue ID to retrieve parameters for + + Returns: + A dictionary containing all stored recall parameters + """ + logger = ApiDependencies.invoker.services.logger + + try: + # Retrieve all recall parameters by iterating through expected keys + # Since client_state_persistence doesn't have a "get_all" method, we'll + # return an informative response + return { + "status": "success", + "queue_id": queue_id, + "note": "Use the frontend to access stored recall parameters, or set specific parameters using POST", + } + + except Exception as e: + logger.error(f"Error retrieving recall parameters: {e}") + raise HTTPException( + status_code=500, + detail="Error retrieving recall parameters", + ) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 188f958c887..9db16aa2d2f 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -35,6 +35,7 @@ QueueClearedEvent, QueueEventBase, QueueItemStatusChangedEvent, + RecallParametersUpdatedEvent, register_events, ) @@ -61,6 +62,7 @@ class BulkDownloadSubscriptionEvent(BaseModel): QueueItemStatusChangedEvent, BatchEnqueuedEvent, QueueClearedEvent, + RecallParametersUpdatedEvent, } MODEL_EVENTS = { diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 335327f532b..e86a397c414 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -24,6 +24,7 @@ images, model_manager, model_relationships, + recall_parameters, session_queue, style_presets, utilities, @@ -133,6 +134,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): app.include_router(workflows.workflows_router, prefix="/api") app.include_router(style_presets.style_presets_router, prefix="/api") app.include_router(client_state.client_state_router, prefix="/api") +app.include_router(recall_parameters.recall_parameters_router, prefix="/api") app.openapi = get_openapi_func(app) diff --git a/invokeai/app/services/events/events_base.py 
b/invokeai/app/services/events/events_base.py index c70ef3fa16e..4c2d3c5c4cb 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -30,6 +30,7 @@ QueueClearedEvent, QueueItemsRetriedEvent, QueueItemStatusChangedEvent, + RecallParametersUpdatedEvent, ) if TYPE_CHECKING: @@ -110,6 +111,10 @@ def emit_queue_cleared(self, queue_id: str) -> None: """Emitted when a queue is cleared""" self.dispatch(QueueClearedEvent.build(queue_id)) + def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None: + """Emitted when recall parameters are updated""" + self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters)) + # endregion # region Download diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index a924f2eed9f..082eb8a6b40 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -646,3 +646,16 @@ def build( bulk_download_item_name=bulk_download_item_name, error=error, ) + + +@payload_schema.register +class RecallParametersUpdatedEvent(QueueEventBase): + """Event model for recall_parameters_updated""" + + __event_name__ = "recall_parameters_updated" + + parameters: dict[str, Any] = Field(description="The recall parameters that were updated") + + @classmethod + def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": + return cls(queue_id=queue_id, parameters=parameters) diff --git a/invokeai/backend/image_util/controlnet_processor.py b/invokeai/backend/image_util/controlnet_processor.py new file mode 100644 index 00000000000..87739f69e1c --- /dev/null +++ b/invokeai/backend/image_util/controlnet_processor.py @@ -0,0 +1,170 @@ +"""Utilities for processing images with ControlNet processors.""" + +from datetime import datetime +from typing import Any, Optional + +from invokeai.app.invocations.fields import ImageField +from 
invokeai.app.services.invoker import InvocationServices +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context + + +def _get_processor_invocation_class(processor_type: str): + """Get the invocation class for a processor type.""" + # Import processor invocation classes on demand + processor_class_map = { + "canny_image_processor": lambda: __import__( + "invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"] + ).CannyEdgeDetectionInvocation, + "hed_image_processor": lambda: __import__( + "invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"] + ).HEDEdgeDetectionInvocation, + "mlsd_image_processor": lambda: __import__( + "invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"] + ).MLSDDetectionInvocation, + "depth_anything_image_processor": lambda: __import__( + "invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"] + ).DepthAnythingDepthEstimationInvocation, + "normalbae_image_processor": lambda: __import__( + "invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"] + ).NormalMapInvocation, + "pidi_image_processor": lambda: __import__( + "invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"] + ).PiDiNetEdgeDetectionInvocation, + "lineart_image_processor": lambda: __import__( + "invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"] + ).LineartEdgeDetectionInvocation, + "lineart_anime_image_processor": lambda: __import__( + "invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"] + ).LineartAnimeEdgeDetectionInvocation, + "content_shuffle_image_processor": lambda: __import__( + "invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"] + 
).ContentShuffleInvocation, + "dw_openpose_image_processor": lambda: __import__( + "invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"] + ).DWOpenposeDetectionInvocation, + "mediapipe_face_processor": lambda: __import__( + "invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"] + ).MediaPipeFaceDetectionInvocation, + # Note: zoe_depth_image_processor doesn't have a processor invocation implementation + "color_map_image_processor": lambda: __import__( + "invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"] + ).ColorMapInvocation, + } + + if processor_type in processor_class_map: + return processor_class_map[processor_type]() + return None + + +# Map processor type names to their default parameters +PROCESSOR_DEFAULT_PARAMS = { + "canny_image_processor": {"low_threshold": 100, "high_threshold": 200}, + "hed_image_processor": {"scribble": False}, + "mlsd_image_processor": {"detect_resolution": 512, "thr_v": 0.1, "thr_d": 0.1}, + "depth_anything_image_processor": {"model_size": "small"}, + "normalbae_image_processor": {"detect_resolution": 512}, + "pidi_image_processor": {"detect_resolution": 512, "safe": False}, + "lineart_image_processor": {"detect_resolution": 512, "coarse": False}, + "lineart_anime_image_processor": {"detect_resolution": 512}, + "content_shuffle": {}, + "dw_openpose_image_processor": {"draw_body": True, "draw_face": True, "draw_hands": True}, + "mediapipe_face_processor": {"max_faces": 1, "min_confidence": 0.5}, + "zoe_depth_image_processor": {}, + "color_map_image_processor": {"color_map_tile_size": 64}, +} + + +def process_controlnet_image(image_name: str, model_key: str, services: InvocationServices) -> Optional[dict[str, Any]]: + """ + Process a controlnet image using the appropriate processor based on the model's default settings. 
+ + Args: + image_name: The filename of the image to process + model_key: The model key to look up default processor settings + services: The invocation services providing access to models and images + + Returns: + A dictionary with the processed image data (image_name, width, height) or None if processing fails + """ + logger = services.logger + + try: + # Get model config to find default processor + model_record = services.model_manager.store.get_model(model_key) + if not model_record or not model_record.default_settings: + logger.info(f"No default processor settings found for model {model_key}") + return None + + preprocessor = model_record.default_settings.preprocessor + if not preprocessor: + logger.info(f"No preprocessor configured for model {model_key}") + return None + + # Get the invocation class for this processor + invocation_class = _get_processor_invocation_class(preprocessor) + if not invocation_class: + logger.info(f"No processor mapping found for preprocessor '{preprocessor}'") + return None + + # Get default parameters for this processor + default_params = PROCESSOR_DEFAULT_PARAMS.get(preprocessor, {}) + logger.info(f"Processing image {image_name} with processor {preprocessor}") + + # Create a minimal context to run the invocation + # We need a fake queue item and session for the context + fake_session = GraphExecutionState(graph=Graph()) + now = datetime.now() + + # Create invocation instance first so we have its ID + invocation_params = {"image": ImageField(image_name=image_name), **default_params} + invocation = invocation_class(**invocation_params) + + # Add the invocation ID to the session's prepared_source_mapping + # This is required for the invocation context to emit progress events + fake_session.prepared_source_mapping[invocation.id] = invocation.id + + fake_queue_item = SessionQueueItem( + item_id=0, + session_id=fake_session.id, + queue_id="default", + batch_id="recall_processor", + field_values=None, + session=fake_session, + 
status="in_progress", + created_at=now, + updated_at=now, + started_at=now, + completed_at=None, + ) + + context_data = InvocationContextData( + invocation=invocation, + source_invocation_id=invocation.id, + queue_item=fake_queue_item, + ) + + context = build_invocation_context( + data=context_data, + services=services, + is_canceled=lambda: False, + ) + + # Invoke the processor + output = invocation.invoke(context) + + # Get the processed image DTO + processed_image_dto = services.images.get_dto(output.image.image_name) + + logger.info(f"Successfully processed image {image_name} -> {processed_image_dto.image_name}") + + return { + "image_name": processed_image_dto.image_name, + "width": processed_image_dto.width, + "height": processed_image_dto.height, + } + + except Exception as e: + logger.error(f"Error processing controlnet image {image_name}: {e}", exc_info=True) + return None diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx index 0d9bd14955a..84c1b2fc37b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx @@ -1,5 +1,5 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; -import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library'; +import { Flex, Icon, IconButton, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { round } from 'es-toolkit/compat'; import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity'; @@ -144,15 +144,16 @@ export const RefImagePreview = memo(() => { cursor="pointer" overflow="hidden" > - } - maxW="full" - maxH="full" - /> + {imageDTO ? 
( + {imageDTO.image_name} + ) : ( + + )} {isIPAdapterConfig(entity.config) && ( { fineStep={0.01} minStepsBetweenThumbs={1} formatValue={formatPct} - marks + marks={[0, 0.25, 0.5, 0.75, 1]} withThumbTooltip /> diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 868e2aa2b78..312feb6f5a1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1778,7 +1778,7 @@ export const { rasterLayerConvertedToRegionalGuidance, // Control layers controlLayerAdded, - // controlLayerRecalled, + controlLayerRecalled, controlLayerConvertedToRasterLayer, controlLayerConvertedToInpaintMask, controlLayerConvertedToRegionalGuidance, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 1f0464d1cc4..adafcd9dde4 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1958,6 +1958,61 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/recall/{queue_id}": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Recall Parameters + * @description Retrieve all stored recall parameters for a given queue. + * + * Returns a dictionary of all recall parameters that have been set for the queue. + * + * Args: + * queue_id: The queue ID to retrieve parameters for + * + * Returns: + * A dictionary containing all stored recall parameters + */ + get: operations["get_recall_parameters"]; + put?: never; + /** + * Update Recall Parameters + * @description Update recallable parameters that can be recalled on the frontend. + * + * This endpoint allows updating parameters such as prompt, model, steps, and other + * generation settings. 
These parameters are stored in client state and can be + * accessed by the frontend to populate UI elements. + * + * Args: + * queue_id: The queue ID to associate these parameters with + * parameters: The RecallParameter object containing the parameters to update + * + * Returns: + * A dictionary containing the updated parameters and status + * + * Example: + * POST /api/v1/recall/{queue_id} + * { + * "positive_prompt": "a beautiful landscape", + * "model": "sd-1.5", + * "steps": 20, + * "cfg_scale": 7.5, + * "width": 512, + * "height": 512, + * "seed": 12345 + * } + */ + post: operations["update_recall_parameters"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; }; export type webhooks = Record; export type components = { @@ -5097,6 +5152,43 @@ export type components = { */ resize_mode?: "just_resize" | "crop_resize" | "fill_resize" | "just_resize_simple"; }; + /** + * ControlNetRecallParameter + * @description ControlNet configuration for recall + */ + ControlNetRecallParameter: { + /** + * Model Name + * @description The name of the ControlNet/T2I Adapter/Control LoRA model + */ + model_name: string; + /** + * Image Name + * @description The filename of the control image in outputs/images + */ + image_name?: string | null; + /** + * Weight + * @description The weight for the control adapter + * @default 1 + */ + weight?: number; + /** + * Begin Step Percent + * @description When the control adapter is first applied (% of total steps) + */ + begin_step_percent?: number | null; + /** + * End Step Percent + * @description When the control adapter is last applied (% of total steps) + */ + end_step_percent?: number | null; + /** + * Control Mode + * @description The control mode (ControlNet only) + */ + control_mode?: ("balanced" | "more_prompt" | "more_control") | null; + }; /** ControlNet_Checkpoint_FLUX_Config */ ControlNet_Checkpoint_FLUX_Config: { /** @@ -9880,6 +9972,48 @@ export type components = { */ type: 
"ip_adapter_output"; }; + /** + * IPAdapterRecallParameter + * @description IP Adapter configuration for recall + */ + IPAdapterRecallParameter: { + /** + * Model Name + * @description The name of the IP Adapter model + */ + model_name: string; + /** + * Image Name + * @description The filename of the reference image in outputs/images + */ + image_name?: string | null; + /** + * Weight + * @description The weight for the IP Adapter + * @default 1 + */ + weight?: number; + /** + * Begin Step Percent + * @description When the IP Adapter is first applied (% of total steps) + */ + begin_step_percent?: number | null; + /** + * End Step Percent + * @description When the IP Adapter is last applied (% of total steps) + */ + end_step_percent?: number | null; + /** + * Method + * @description The IP Adapter method + */ + method?: ("full" | "style" | "composition") | null; + /** + * Image Influence + * @description FLUX Redux image influence (if model is flux_redux) + */ + image_influence?: ("lowest" | "low" | "medium" | "high" | "highest") | null; + }; /** IPAdapter_Checkpoint_FLUX_Config */ IPAdapter_Checkpoint_FLUX_Config: { /** @@ -14666,6 +14800,29 @@ export type components = { */ weight: number; }; + /** + * LoRARecallParameter + * @description LoRA configuration for recall + */ + LoRARecallParameter: { + /** + * Model Name + * @description The name of the LoRA model + */ + model_name: string; + /** + * Weight + * @description The weight for the LoRA + * @default 0.75 + */ + weight?: number; + /** + * Is Enabled + * @description Whether the LoRA is enabled + * @default true + */ + is_enabled?: boolean; + }; /** * Select LoRA * @description Selects a LoRA model and weight. @@ -20719,6 +20876,160 @@ export type components = { */ type: "range_of_size"; }; + /** + * RecallParameter + * @description Request model for updating recallable parameters. 
+ */ + RecallParameter: { + /** + * Positive Prompt + * @description Positive prompt text + */ + positive_prompt?: string | null; + /** + * Negative Prompt + * @description Negative prompt text + */ + negative_prompt?: string | null; + /** + * Model + * @description Main model name/identifier + */ + model?: string | null; + /** + * Refiner Model + * @description Refiner model name/identifier + */ + refiner_model?: string | null; + /** + * Vae Model + * @description VAE model name/identifier + */ + vae_model?: string | null; + /** + * Scheduler + * @description Scheduler name + */ + scheduler?: string | null; + /** + * Steps + * @description Number of generation steps + */ + steps?: number | null; + /** + * Refiner Steps + * @description Number of refiner steps + */ + refiner_steps?: number | null; + /** + * Cfg Scale + * @description CFG scale for guidance + */ + cfg_scale?: number | null; + /** + * Cfg Rescale Multiplier + * @description CFG rescale multiplier + */ + cfg_rescale_multiplier?: number | null; + /** + * Refiner Cfg Scale + * @description Refiner CFG scale + */ + refiner_cfg_scale?: number | null; + /** + * Guidance + * @description Guidance scale + */ + guidance?: number | null; + /** + * Width + * @description Image width in pixels + */ + width?: number | null; + /** + * Height + * @description Image height in pixels + */ + height?: number | null; + /** + * Seed + * @description Random seed + */ + seed?: number | null; + /** + * Denoise Strength + * @description Denoising strength + */ + denoise_strength?: number | null; + /** + * Refiner Denoise Start + * @description Refiner denoising start + */ + refiner_denoise_start?: number | null; + /** + * Clip Skip + * @description CLIP skip layers + */ + clip_skip?: number | null; + /** + * Seamless X + * @description Enable seamless X tiling + */ + seamless_x?: boolean | null; + /** + * Seamless Y + * @description Enable seamless Y tiling + */ + seamless_y?: boolean | null; + /** + * Refiner Positive 
Aesthetic Score + * @description Refiner positive aesthetic score + */ + refiner_positive_aesthetic_score?: number | null; + /** + * Refiner Negative Aesthetic Score + * @description Refiner negative aesthetic score + */ + refiner_negative_aesthetic_score?: number | null; + /** + * Loras + * @description List of LoRAs with their weights + */ + loras?: components["schemas"]["LoRARecallParameter"][] | null; + /** + * Control Layers + * @description List of control adapters (ControlNet, T2I Adapter, Control LoRA) with their settings + */ + control_layers?: components["schemas"]["ControlNetRecallParameter"][] | null; + /** + * Ip Adapters + * @description List of IP Adapters with their settings + */ + ip_adapters?: components["schemas"]["IPAdapterRecallParameter"][] | null; + }; + /** + * RecallParametersUpdatedEvent + * @description Event model for recall_parameters_updated + */ + RecallParametersUpdatedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Parameters + * @description The recall parameters that were updated + */ + parameters: { + [key: string]: unknown; + }; + }; /** * Create Rectangle Mask * @description Create a rectangular mask. 
@@ -30429,4 +30740,76 @@ export interface operations { }; }; }; + get_recall_parameters: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The queue id to retrieve parameters for */ + queue_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": { + [key: string]: unknown; + }; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + update_recall_parameters: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The queue id to perform this operation on */ + queue_id: string; + }; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["RecallParameter"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": { + [key: string]: unknown; + }; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; } diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index f998627d26c..677f4d658d1 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -4,8 +4,28 @@ import { socketConnected } from 'app/store/middleware/listenerMiddleware/listene import type { AppStore } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { forEach, isNil, round } from 'es-toolkit/compat'; +import { allEntitiesDeleted, controlLayerRecalled } from 
'features/controlLayers/store/canvasSlice'; +import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice'; +import { + heightChanged, + negativePromptChanged, + positivePromptChanged, + setCfgScale, + setSeed, + setSteps, + widthChanged, +} from 'features/controlLayers/store/paramsSlice'; +import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; +import type { + ControlModeV2, + FLUXReduxImageInfluence, + IPMethodV2, + RefImageState, +} from 'features/controlLayers/store/types'; +import { getControlLayerState, getReferenceImageState } from 'features/controlLayers/store/util'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; +import { modelSelected } from 'features/parameters/store/actions'; import ErrorToastDescription, { getTitle } from 'features/toast/ErrorToastDescription'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; @@ -13,6 +33,7 @@ import { LRUCache } from 'lru-cache'; import { Trans } from 'react-i18next'; import type { ApiTagDescription } from 'services/api'; import { api, LIST_ALL_TAG, LIST_TAG } from 'services/api'; +import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi } from 'services/api/endpoints/queue'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; @@ -453,6 +474,321 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug({ data }, 'Queue items retried'); }); + socket.on('recall_parameters_updated', (data) => { + log.debug('Recall parameters updated'); + + // Apply the recall parameters to the store + if (data.parameters) { + let appliedCount = 0; + + // Map the recall parameter names to store actions + if (data.parameters.positive_prompt !== undefined && typeof data.parameters.positive_prompt === 
'string') { + dispatch(positivePromptChanged(data.parameters.positive_prompt)); + appliedCount++; + } + if (data.parameters.negative_prompt !== undefined && typeof data.parameters.negative_prompt === 'string') { + dispatch(negativePromptChanged(data.parameters.negative_prompt)); + appliedCount++; + } + if (data.parameters.width !== undefined && typeof data.parameters.width === 'number') { + dispatch(widthChanged({ width: data.parameters.width })); + appliedCount++; + } + if (data.parameters.height !== undefined && typeof data.parameters.height === 'number') { + dispatch(heightChanged({ height: data.parameters.height })); + appliedCount++; + } + if (data.parameters.seed !== undefined && typeof data.parameters.seed === 'number') { + dispatch(setSeed(data.parameters.seed)); + appliedCount++; + } + if (data.parameters.steps !== undefined && typeof data.parameters.steps === 'number') { + dispatch(setSteps(data.parameters.steps)); + appliedCount++; + } + if (data.parameters.cfg_scale !== undefined && typeof data.parameters.cfg_scale === 'number') { + dispatch(setCfgScale(data.parameters.cfg_scale)); + appliedCount++; + } + + // Handle model - requires looking up the full model config + if (data.parameters.model !== undefined && typeof data.parameters.model === 'string') { + dispatch(modelsApi.endpoints.getModelConfig.initiate(data.parameters.model)) + .unwrap() + .then((modelConfig) => { + if (modelConfig.type === 'main') { + dispatch(modelSelected(modelConfig)); + log.debug(`Applied model: ${modelConfig.name}`); + } else { + log.warn(`Model ${data.parameters.model} is not a main model, skipping`); + } + }) + .catch((error) => { + log.error(`Failed to load model ${data.parameters.model}: ${error}`); + }); + appliedCount++; + } + + if (appliedCount > 0) { + log.info(`Applied ${appliedCount} recall parameters to store`); + } + + // Handle LoRAs + if (data.parameters.loras !== undefined && Array.isArray(data.parameters.loras)) { + log.debug(`Processing 
${data.parameters.loras.length} LoRA(s)`); + + // Clear existing LoRAs first + dispatch(loraAllDeleted()); + + // Add each LoRA + for (const loraConfig of data.parameters.loras) { + if (loraConfig.model_key && typeof loraConfig.model_key === 'string') { + dispatch(modelsApi.endpoints.getModelConfig.initiate(loraConfig.model_key)) + .unwrap() + .then((modelConfig) => { + if (modelConfig.type === 'lora') { + const lora = { + id: `recalled-${Date.now()}-${Math.random()}`, + model: { + key: modelConfig.key, + hash: modelConfig.hash, + name: modelConfig.name, + base: modelConfig.base, + type: modelConfig.type, + }, + weight: typeof loraConfig.weight === 'number' ? loraConfig.weight : 0.75, + isEnabled: typeof loraConfig.is_enabled === 'boolean' ? loraConfig.is_enabled : true, + }; + dispatch(loraRecalled({ lora })); + log.debug(`Applied LoRA: ${modelConfig.name} (weight: ${lora.weight})`); + } else { + log.warn(`Model ${loraConfig.model_key} is not a LoRA, skipping`); + } + }) + .catch((error) => { + log.error(`Failed to load LoRA ${loraConfig.model_key}: ${error}`); + }); + } + } + log.info(`Initiated loading of ${data.parameters.loras.length} LoRA(s)`); + } + + // Handle Control Layers + if (data.parameters.control_layers !== undefined && Array.isArray(data.parameters.control_layers)) { + log.debug(`Processing ${data.parameters.control_layers.length} control layer(s)`); + + // If the list is explicitly empty, clear all existing control layers + if (data.parameters.control_layers.length === 0) { + dispatch(allEntitiesDeleted()); + log.info('Cleared all control layers'); + } else { + // Replace existing control layers by first clearing them + dispatch(allEntitiesDeleted()); + + // Then add each new control layer + data.parameters.control_layers.forEach( + (controlConfig: { + model_key: string; + weight?: number; + begin_step_percent?: number; + end_step_percent?: number; + control_mode?: ControlModeV2; + image?: { image_name: string; width: number; height: number }; + 
processed_image?: { image_name: string; width: number; height: number }; + }) => { + if (controlConfig.model_key && typeof controlConfig.model_key === 'string') { + dispatch(modelsApi.endpoints.getModelConfig.initiate(controlConfig.model_key)) + .unwrap() + .then(async (modelConfig) => { + // Pre-fetch the image DTO if an image is provided, to avoid validation errors + let imageObjects: Array<{ + id: string; + type: 'image'; + image: { image_name: string; width: number; height: number }; + }> = []; + if (controlConfig.image?.image_name) { + try { + // Use the processed image if available, otherwise use the original + const imageToUse = controlConfig.processed_image || controlConfig.image; + await dispatch(imagesApi.endpoints.getImageDTO.initiate(imageToUse.image_name)).unwrap(); + // Add the image to the control layer's objects array + imageObjects = [ + { + id: `recalled-image-${Date.now()}-${Math.random()}`, + type: 'image' as const, + image: { + image_name: imageToUse.image_name, + width: imageToUse.width, + height: imageToUse.height, + }, + }, + ]; + if (controlConfig.processed_image) { + log.debug( + `Pre-fetched processed control layer image: ${imageToUse.image_name} (${imageToUse.width}x${imageToUse.height})` + ); + } else { + log.debug( + `Pre-fetched control layer image: ${imageToUse.image_name} (${imageToUse.width}x${imageToUse.height})` + ); + } + } catch (imageError) { + log.warn( + `Could not pre-fetch control layer image ${controlConfig.image.image_name}, continuing without image: ${imageError}` + ); + } + } + + // Build a valid CanvasControlLayerState using helper function + const controlLayerState = getControlLayerState(`recalled-control-${Date.now()}-${Math.random()}`, { + objects: imageObjects, + controlAdapter: { + type: 'controlnet', + model: { + key: modelConfig.key, + hash: modelConfig.hash, + name: modelConfig.name, + base: modelConfig.base, + type: modelConfig.type, + }, + weight: typeof controlConfig.weight === 'number' ? 
controlConfig.weight : 1.0, + beginEndStepPct: [ + typeof controlConfig.begin_step_percent === 'number' ? controlConfig.begin_step_percent : 0, + typeof controlConfig.end_step_percent === 'number' ? controlConfig.end_step_percent : 1, + ] as [number, number], + controlMode: controlConfig.control_mode || 'balanced', + }, + }); + + dispatch(controlLayerRecalled({ data: controlLayerState })); + log.debug( + `Applied control layer: ${modelConfig.name} (weight: ${controlLayerState.controlAdapter.weight})` + ); + if (imageObjects.length > 0) { + log.info(`Control layer image loaded: ${controlConfig.image?.image_name}`); + } + }) + .catch((error) => { + log.error(`Failed to load control layer ${controlConfig.model_key}: ${error}`); + }); + } + } + ); + log.info(`Initiated loading of ${data.parameters.control_layers.length} control layer(s)`); + } + } + + // Handle IP Adapters as Reference Images + if (data.parameters.ip_adapters !== undefined && Array.isArray(data.parameters.ip_adapters)) { + log.debug(`Processing ${data.parameters.ip_adapters.length} IP adapter(s)`); + + // If the list is explicitly empty, clear existing reference images + if (data.parameters.ip_adapters.length === 0) { + dispatch(refImagesRecalled({ entities: [], replace: true })); + log.info('Cleared all IP adapter reference images'); + } else { + // Build promises for all IP adapters, then dispatch once with replace: true + const ipAdapterPromises = data.parameters.ip_adapters + .filter((cfg) => cfg.model_key && typeof cfg.model_key === 'string') + .map(async (adapterConfig) => { + try { + const modelConfig = await dispatch( + modelsApi.endpoints.getModelConfig.initiate(adapterConfig.model_key!) 
+ ).unwrap(); + + // Pre-fetch the image DTO if an image is provided, to avoid validation errors + if (adapterConfig.image?.image_name) { + try { + await dispatch(imagesApi.endpoints.getImageDTO.initiate(adapterConfig.image.image_name)).unwrap(); + } catch (imageError) { + log.warn( + `Could not pre-fetch image ${adapterConfig.image.image_name}, continuing anyway: ${imageError}` + ); + } + } + + // Build RefImageState using helper function - supports both ip_adapter and flux_redux + const imageData = adapterConfig.image + ? { + original: { + image: { + image_name: adapterConfig.image.image_name, + width: adapterConfig.image.width ?? 512, + height: adapterConfig.image.height ?? 512, + }, + }, + } + : null; + + const isFluxRedux = modelConfig.type === 'flux_redux'; + const refImageState = getReferenceImageState(`recalled-ref-image-${Date.now()}-${Math.random()}`, { + isEnabled: true, + config: isFluxRedux + ? { + type: 'flux_redux', + image: imageData, + model: { + key: modelConfig.key, + hash: modelConfig.hash, + name: modelConfig.name, + base: modelConfig.base, + type: modelConfig.type, + }, + imageInfluence: (adapterConfig.image_influence as FLUXReduxImageInfluence) || 'highest', + } + : { + type: 'ip_adapter', + image: imageData, + model: { + key: modelConfig.key, + hash: modelConfig.hash, + name: modelConfig.name, + base: modelConfig.base, + type: modelConfig.type, + }, + weight: typeof adapterConfig.weight === 'number' ? adapterConfig.weight : 1.0, + beginEndStepPct: [ + typeof adapterConfig.begin_step_percent === 'number' ? adapterConfig.begin_step_percent : 0, + typeof adapterConfig.end_step_percent === 'number' ? 
adapterConfig.end_step_percent : 1, + ] as [number, number], + method: (adapterConfig.method as IPMethodV2) || 'full', + clipVisionModel: 'ViT-H', + }, + }); + + if (isFluxRedux) { + log.debug(`Built FLUX Redux ref image state: ${modelConfig.name}`); + } else { + log.debug( + `Built IP adapter ref image state: ${modelConfig.name} (weight: ${typeof adapterConfig.weight === 'number' ? adapterConfig.weight : 1.0})` + ); + } + if (adapterConfig.image?.image_name) { + log.debug( + `IP adapter image: outputs/images/${adapterConfig.image.image_name} (${adapterConfig.image.width}x${adapterConfig.image.height})` + ); + } + + return refImageState; + } catch (error) { + log.error(`Failed to load IP adapter ${adapterConfig.model_key}: ${error}`); + return null; + } + }); + + // Wait for all IP adapters to load, then dispatch with replace: true + Promise.all(ipAdapterPromises).then((refImageStates) => { + const validStates = refImageStates.filter((state): state is RefImageState => state !== null); + if (validStates.length > 0) { + dispatch(refImagesRecalled({ entities: validStates, replace: true })); + log.info(`Applied ${validStates.length} IP adapter(s), replacing existing list`); + } + }); + } + } + } + }); + socket.on('bulk_download_started', (data) => { log.debug({ data }, 'Bulk gallery download preparation started'); }); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 7bbdc35e5e8..8937dcc451d 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -29,6 +29,7 @@ export type ServerToClientEvents = { queue_cleared: (payload: S['QueueClearedEvent']) => void; batch_enqueued: (payload: S['BatchEnqueuedEvent']) => void; queue_items_retried: (payload: S['QueueItemsRetriedEvent']) => void; + recall_parameters_updated: (payload: S['RecallParametersUpdatedEvent']) => void; bulk_download_started: (payload: S['BulkDownloadStartedEvent']) => 
void; bulk_download_complete: (payload: S['BulkDownloadCompleteEvent']) => void; bulk_download_error: (payload: S['BulkDownloadErrorEvent']) => void;