diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 388c2d1a87cd..fd026f07c923 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -54,6 +54,11 @@ image = pipe( image.save("sd3_hello_world.png") ``` +**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well. In total there are three official models in the SD3 family: +- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) +- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) +- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) + ## Memory Optimisations for SD3 SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 4f32745dae75..1f9c434b39d0 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -16,10 +16,9 @@ parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_path", type=str) parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="fp16") +parser.add_argument("--dtype", type=str) args = parser.parse_args() -dtype = torch.float16 if args.dtype == "fp16" else torch.float32 def load_original_checkpoint(ckpt_path): @@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim): return new_weight -def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim): +def convert_sd3_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm +): converted_state_dict = {} # Positional and patch embeddings. @@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + # output projections. converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.attn.proj.weight" @@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay f"joint_blocks.{i}.context_block.attn.proj.bias" ) + # attn2 + if i in dual_attention_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + # norms. converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" @@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict): ) +def get_attn2_layers(state_dict): + attn2_layers = [] + for key in state_dict.keys(): + if "attn2." in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + attn2_layers.append(layer_num) + return tuple(sorted(set(attn2_layers))) + + +def get_pos_embed_max_size(state_dict): + num_patches = state_dict["pos_embed"].shape[1] + pos_embed_max_size = int(num_patches**0.5) + return pos_embed_max_size + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + def main(args): original_ckpt = load_original_checkpoint(args.checkpoint_path) + original_dtype = next(iter(original_ckpt.values())).dtype + + # Initialize dtype with a default value + dtype = None + + if args.dtype is None: + dtype = original_dtype + elif args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + if dtype != original_dtype: + print( + f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution." + ) + num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 - caption_projection_dim = 1536 + + caption_projection_dim = get_caption_projection_dim(original_ckpt) + + # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + attn2_layers = get_attn2_layers(original_ckpt) + + # sd3.5 use qk norm("rms_norm") + has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) + + # sd3.5 2b use pox_embed_max_size=384 and sd3.0 and sd3.5 8b use 192 + pos_embed_max_size = get_pos_embed_max_size(original_ckpt) converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( - original_ckpt, num_layers, caption_projection_dim + original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm ) with CTX(): transformer = SD3Transformer2DModel( - sample_size=64, + sample_size=128, patch_size=2, in_channels=16, joint_attention_dim=4096, num_layers=num_layers, caption_projection_dim=caption_projection_dim, - num_attention_heads=24, - pos_embed_max_size=192, + num_attention_heads=num_layers, + pos_embed_max_size=pos_embed_max_size, + qk_norm="rms_norm" if has_qk_norm else None, + dual_attention_layers=attn2_layers, ) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_transformer_state_dict) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 84db0d061768..02ed1f965abf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX logger = logging.get_logger(__name__) @@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module): processing of `context` conditions. """ - def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): super().__init__() + self.use_dual_attention = use_dual_attention self.context_pre_only = context_pre_only context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - self.norm1 = AdaLayerNormZero(dim) + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( @@ -118,12 +130,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl raise ValueError( f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" ) + if hasattr(F, "scaled_dot_product_attention"): processor = JointAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." ) + self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -134,8 +148,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl context_pre_only=context_pre_only, bias=True, processor=processor, + qk_norm=qk_norm, + eps=1e-6, ) + if use_dual_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") @@ -159,7 +190,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) @@ -177,6 +213,11 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d333590982e3..e735c4ee7d17 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -193,7 +193,7 @@ def __init__( self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") if cross_attention_norm is None: self.norm_cross = None @@ -250,6 +250,10 @@ def __init__( elif qk_norm == "rms_norm": self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) else: self.norm_added_q = None self.norm_added_k = None @@ -1050,61 +1054,72 @@ def __call__( ) -> torch.FloatTensor: residual = hidden_states - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size = hidden_states.shape[0] # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - # attention - query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) - inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # Split the attention outputs. - hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states, encoder_hidden_states + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states class PAGJointAttnProcessor2_0: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 21e9d3cd6fc5..029c147fcbac 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -97,6 +97,40 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ).to(origin_dtype) +class SD35AdaLayerNormZeroX(nn.Module): + r""" + Norm layer adaptive layer norm zero (AdaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.") + + def forward( + self, + hidden_states: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, ...]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk( + 9, dim=1 + ) + norm_hidden_states = self.norm(hidden_states) + hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None] + norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 + + class AdaLayerNormZero(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 9376c91d0756..b28350b8ed9c 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -69,6 +69,10 @@ def __init__( pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, + dual_attention_layers: Tuple[ + int, ... + ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + qk_norm: Optional[str] = None, ): super().__init__() default_out_channels = in_channels @@ -97,6 +101,8 @@ def __init__( num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, context_pre_only=i == num_layers - 1, + qk_norm=qk_norm, + use_dual_attention=True if i in dual_attention_layers else False, ) for i in range(self.config.num_layers) ] diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 2b9084327289..2be4744c5ac4 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self): "joint_attention_dim": 32, "pooled_projection_dim": 64, "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (), + "qk_norm": None, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") + def test_set_attn_processor_for_determinism(self): + pass + + +class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = SD3Transformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = embedding_dim = 32 + pooled_embedding_dim = embedding_dim * 2 + sequence_length = 154 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (0,), + "qk_norm": "rms_norm", } inputs_dict = self.dummy_input return init_dict, inputs_dict