Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
74b8389
Add LLM-powered prompt expansion and image-to-prompt features
Pfannkuchensack Feb 23, 2026
af1d87b
chore fix windows paths
Pfannkuchensack Feb 23, 2026
c027bcc
Merge branch 'main' into feature/llm-prompt-tools
JPPhoto Feb 23, 2026
1eff89d
Fix device mismatch for LLM inference and add CPU-only toggle for Tex…
Pfannkuchensack Feb 23, 2026
593693c
Harden LLM endpoints and add tests
Pfannkuchensack Feb 23, 2026
370bdd8
Add Ctrl+Z undo for LLM prompt changes
Pfannkuchensack Feb 23, 2026
93c2d79
Add documentation and What's New entry for LLM prompt tools
Pfannkuchensack Feb 24, 2026
98f8655
Merge remote-tracking branch 'origin/main' into feature/llm-prompt-tools
Pfannkuchensack Feb 28, 2026
506c5e1
fix: resolve merge conflict in mkdocs.yml nav
Pfannkuchensack Feb 28, 2026
e8852e0
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Feb 28, 2026
c4547f5
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Mar 3, 2026
b9dbf7b
feat(ui): allow dragging gallery images onto prompt box for Image to …
Pfannkuchensack Mar 5, 2026
a48dd39
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Mar 15, 2026
1b8e5f6
chore typegen
Pfannkuchensack Mar 15, 2026
43b1bcc
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Mar 15, 2026
51fafe3
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Mar 24, 2026
a6d5b85
Merge branch 'main' into feature/llm-prompt-tools
JPPhoto Mar 24, 2026
4435e56
Fix typo in Z-Image Turbo diversity description
Pfannkuchensack Mar 24, 2026
75acb9d
Merge branch 'main' into feature/llm-prompt-tools
Pfannkuchensack Mar 26, 2026
787d4b1
Fix three bugs in LLM/VLM utility endpoints
Pfannkuchensack Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 148 additions & 1 deletion invokeai/app/api/routers/utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import asyncio
import logging
from pathlib import Path
from typing import Optional, Union

import torch
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pyparsing import ParseException
from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline
from invokeai.backend.util.devices import TorchDevice

logger = logging.getLogger(__name__)

utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])

Expand Down Expand Up @@ -42,3 +54,138 @@ async def parse_dynamicprompts(
prompts = [prompt]
error = str(e)
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)


# --- Expand Prompt ---


class ExpandPromptRequest(BaseModel):
    """Request body for the /expand-prompt endpoint."""

    # Brief user prompt to be expanded into a detailed generation prompt.
    prompt: str
    # Model manager key of the installed text LLM to run.
    model_key: str
    # Maximum number of new tokens to generate.
    # NOTE(review): unbounded here, while the Text LLM invocation node enforces
    # 1..2048 — consider adding matching validation bounds.
    max_tokens: int = 300
    # Optional system prompt; the handler falls back to DEFAULT_SYSTEM_PROMPT
    # when this is None or empty.
    system_prompt: str | None = None


class ExpandPromptResponse(BaseModel):
    """Response body for the /expand-prompt endpoint."""

    # The LLM-expanded prompt text.
    expanded_prompt: str
    # Error message slot; the visible endpoint raises HTTPException on failure
    # instead of populating this field.
    error: str | None = None


def _resolve_model_path(model_config_path: str) -> Path:
"""Resolve a model config path to an absolute path."""
model_path = Path(model_config_path)
if model_path.is_absolute():
return model_path.resolve()
base_models_path = ApiDependencies.invoker.services.configuration.models_path
return (base_models_path / model_path).resolve()


def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prompt: str | None) -> str:
    """Run text LLM inference synchronously (called from thread)."""
    manager = ApiDependencies.invoker.services.model_manager
    config = manager.store.get_model(model_key)
    loaded = manager.load.load_model(config)

    with loaded.model_on_device() as (_, model):
        # The tokenizer is read from the local model directory (no network access).
        tokenizer = AutoTokenizer.from_pretrained(
            _resolve_model_path(config.path), local_files_only=True
        )

        return TextLLMPipeline(model, tokenizer).run(
            prompt=prompt,
            # Empty or missing system prompts fall back to the pipeline default.
            system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT,
            max_new_tokens=max_tokens,
            device=TorchDevice.choose_torch_device(),
            dtype=TorchDevice.choose_torch_dtype(),
        )


@utilities_router.post(
    "/expand-prompt",
    operation_id="expand_prompt",
    responses={
        200: {"model": ExpandPromptResponse},
    },
)
async def expand_prompt(body: ExpandPromptRequest) -> ExpandPromptResponse:
    """Expand a brief prompt into a detailed image generation prompt using a text LLM.

    Inference runs in a worker thread via asyncio.to_thread so the event loop
    is not blocked. Failures surface as HTTP 500 with the error as detail.
    """

    def _expand() -> str:
        # PyTorch grad mode is thread-local: entering no_grad() on the event-loop
        # thread (as the previous code did around to_thread) has no effect on the
        # worker thread, so it must be entered here, inside the worker.
        with torch.no_grad():
            return _run_expand_prompt(
                body.prompt,
                body.model_key,
                body.max_tokens,
                body.system_prompt,
            )

    try:
        expanded = await asyncio.to_thread(_expand)
        return ExpandPromptResponse(expanded_prompt=expanded)
    except Exception as e:
        # logger.exception records the traceback; chain the cause for debugging.
        logger.exception("Error expanding prompt")
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Image to Prompt ---


class ImageToPromptRequest(BaseModel):
    """Request body for the /image-to-prompt endpoint."""

    # Name of an image in InvokeAI's image store to describe.
    image_name: str
    # Model manager key of the installed vision-language model to run.
    model_key: str
    # Instruction given to the VLM; defaults to a caption-for-prompt request.
    instruction: str = "Describe this image in detail for use as an AI image generation prompt."


class ImageToPromptResponse(BaseModel):
    """Response body for the /image-to-prompt endpoint."""

    # The generated descriptive prompt.
    prompt: str
    # Error message slot; the visible endpoint raises HTTPException on failure
    # instead of populating this field.
    error: str | None = None


def _run_image_to_prompt(image_name: str, model_key: str, instruction: str) -> str:
    """Run LLaVA OneVision inference synchronously (called from thread)."""
    manager = ApiDependencies.invoker.services.model_manager
    config = manager.store.get_model(model_key)
    loaded = manager.load.load_model(config)

    # Fetch the source image from InvokeAI's image store and normalize to RGB.
    image = ApiDependencies.invoker.services.images.get_pil_image(image_name).convert("RGB")

    with loaded.model_on_device() as (_, model):
        assert isinstance(model, LlavaOnevisionForConditionalGeneration)

        # The processor is read from the local model directory (no network access).
        processor = AutoProcessor.from_pretrained(
            _resolve_model_path(config.path), local_files_only=True
        )
        assert isinstance(processor, LlavaOnevisionProcessor)

        return LlavaOnevisionPipeline(model, processor).run(
            prompt=instruction,
            images=[image],
            device=TorchDevice.choose_torch_device(),
            dtype=TorchDevice.choose_torch_dtype(),
        )


@utilities_router.post(
    "/image-to-prompt",
    operation_id="image_to_prompt",
    responses={
        200: {"model": ImageToPromptResponse},
    },
)
async def image_to_prompt(body: ImageToPromptRequest) -> ImageToPromptResponse:
    """Generate a descriptive prompt from an image using a vision-language model.

    Inference runs in a worker thread via asyncio.to_thread so the event loop
    is not blocked. Failures surface as HTTP 500 with the error as detail.
    """

    def _describe() -> str:
        # PyTorch grad mode is thread-local: entering no_grad() on the event-loop
        # thread (as the previous code did around to_thread) has no effect on the
        # worker thread, so it must be entered here, inside the worker.
        with torch.no_grad():
            return _run_image_to_prompt(
                body.image_name,
                body.model_key,
                body.instruction,
            )

    try:
        prompt = await asyncio.to_thread(_describe)
        return ImageToPromptResponse(prompt=prompt)
    except Exception as e:
        # logger.exception records the traceback; chain the cause for debugging.
        logger.exception("Error generating prompt from image")
        raise HTTPException(status_code=500, detail=str(e)) from e
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class FieldDescriptions:
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
text_llm_model = "The text language model to use for text generation"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"

Expand Down
64 changes: 64 additions & 0 deletions invokeai/app/invocations/text_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from transformers import AutoTokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, InputField, UIComponent
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "text_llm",
    title="Text LLM",
    tags=["llm", "text", "prompt"],
    category="llm",
    version="1.0.0",
    classification=Classification.Beta,
)
class TextLLMInvocation(BaseInvocation):
    """Run a text language model to generate or expand text (e.g. for prompt expansion)."""

    # User-supplied text to feed to the model.
    prompt: str = InputField(
        default="",
        description="Input text prompt.",
        ui_component=UIComponent.Textarea,
    )
    # System prompt guiding generation; defaults to the pipeline's standard one.
    system_prompt: str = InputField(
        default=DEFAULT_SYSTEM_PROMPT,
        description="System prompt that guides the model's behavior.",
        ui_component=UIComponent.Textarea,
    )
    # Identifier of the installed causal-LM model, loaded via the model manager.
    text_llm_model: ModelIdentifierField = InputField(
        title="Text LLM Model",
        description=FieldDescriptions.text_llm_model,
        ui_model_type=ModelType.TextLLM,
    )
    # Generation length cap (1..2048 new tokens).
    max_tokens: int = InputField(
        default=300,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> StringOutput:
        """Load the model, run generation, and return the output text."""
        model_config = context.models.get_config(self.text_llm_model)

        # Keep the model on its execution device for the duration of inference.
        with context.models.load(self.text_llm_model).model_on_device() as (_, model):
            # Tokenizer is read from the local model directory (no network access).
            model_abs_path = context.models.get_absolute_path(model_config)
            tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True)

            pipeline = TextLLMPipeline(model, tokenizer)
            output = pipeline.run(
                prompt=self.prompt,
                system_prompt=self.system_prompt,
                max_new_tokens=self.max_tokens,
                device=TorchDevice.choose_torch_device(),
                dtype=TorchDevice.choose_torch_dtype(),
            )

        return StringOutput(value=output)
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
T2IAdapter_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
from invokeai.backend.model_manager.configs.text_llm import TextLLM_Diffusers_Config
from invokeai.backend.model_manager.configs.textual_inversion import (
TI_File_SD1_Config,
TI_File_SD2_Config,
Expand Down Expand Up @@ -248,6 +249,7 @@
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
Annotated[TextLLM_Diffusers_Config, TextLLM_Diffusers_Config.get_tag()],
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Expand Down
44 changes: 44 additions & 0 deletions invokeai/backend/model_manager/configs/text_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import (
Literal,
Self,
)

from pydantic import Field
from typing_extensions import Any

from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_class_name_from_config_dict_or_raise,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
)


class TextLLM_Diffusers_Config(Diffusers_Config_Base, Config_Base):
    """Model config for text-only causal language models (e.g. Llama, Phi, Qwen, Mistral)."""

    type: Literal[ModelType.TextLLM] = Field(default=ModelType.TextLLM)
    # Text LLMs are independent of any specific InvokeAI base model family.
    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for *mod* if it is a text causal LM on disk.

        Raises:
            NotAMatchError: if the model's architecture class name does not end
                in "ForCausalLM" (the identification helpers presumably raise
                similarly for non-directories / invalid overrides — confirm).
        """
        # Diffusers-format models are directories, not single files.
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Check that the model's architecture is a causal language model.
        # This covers LlamaForCausalLM, PhiForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM,
        # MistralForCausalLM, GemmaForCausalLM, GPTNeoXForCausalLM, etc.
        class_name = get_class_name_from_config_dict_or_raise(common_config_paths(mod.path))
        if not class_name.endswith("ForCausalLM"):
            raise NotAMatchError(f"model architecture '{class_name}' is not a causal language model")

        return cls(**override_fields)
26 changes: 26 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/text_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path
from typing import Optional

from transformers import AutoModelForCausalLM

from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextLLM, format=ModelFormat.Diffusers)
class TextLLMModelLoader(ModelLoader):
    """Class for loading text causal language models (Llama, Phi, Qwen, Mistral, etc.)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Plain causal LMs expose no submodels; reject any submodel request.
        if submodel_type is not None:
            raise ValueError("Unexpected submodel requested for TextLLM model.")

        # Load strictly from the local model directory (no network access).
        return AutoModelForCausalLM.from_pretrained(
            Path(config.path),
            local_files_only=True,
            torch_dtype=self._torch_dtype,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,11 @@ def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyMod

for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(repo_id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
size=s.size or 0,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
)
Expand Down
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class ModelType(str, Enum):
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
TextLLM = "text_llm"
Unknown = "unknown"


Expand Down
Loading