@@ -2532,6 +2532,7 @@ def __call__(
     ) -> torch.FloatTensor:
         if image_projection is None:
             raise ValueError("image_projection is None")
+        print(image_projection, image_projection.shape)
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
@@ -2603,103 +2604,30 @@ def __call__(
             # IP-adapter
             ip_hidden_states = image_projection

-            if ip_adapter_masks is not None:
-                if not isinstance(ip_adapter_masks, List):
-                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
-                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
-                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
-                    raise ValueError(
-                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
-                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
-                        f"({len(ip_hidden_states)})"
-                    )
-                else:
-                    for index, (mask, scale, ip_state) in enumerate(
-                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
-                    ):
-                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
-                            raise ValueError(
-                                "Each element of the ip_adapter_masks array should be a tensor with shape "
-                                "[1, num_images_for_ip_adapter, height, width]."
-                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                            )
-                        if mask.shape[1] != ip_state.shape[1]:
-                            raise ValueError(
-                                f"Number of masks ({mask.shape[1]}) does not match "
-                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
-                            )
-                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
-                            raise ValueError(
-                                f"Number of masks ({mask.shape[1]}) does not match "
-                                f"number of scales ({len(scale)}) at index {index}"
-                            )
-            else:
-                ip_adapter_masks = [None] * len(self.scale)
-
             ip_query = hidden_states_query_proj
             ip_attn_output = None
             # for ip-adapter
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
                 ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
             ):
-                skip = False
-                if isinstance(scale, list):
-                    if all(s == 0 for s in scale):
-                        skip = True
-                elif scale == 0:
-                    skip = True
-                if not skip:
-                    if mask is not None:
-                        if not isinstance(scale, list):
-                            scale = [scale] * mask.shape[1]
-
-                        current_num_images = mask.shape[1]
-                        for i in range(current_num_images):
-                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
-                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)

-                            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-                            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                            # TODO: add support for attn.scale when we move to Torch 2.1
-                            _current_ip_hidden_states = F.scaled_dot_product_attention(
-                                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                            )
-
-                            _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
-                                batch_size, -1, attn.heads * head_dim
-                            )
-                            _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
-
-                            mask_downsample = IPAdapterMaskProcessor.downsample(
-                                mask[:, i, :, :],
-                                batch_size,
-                                _current_ip_hidden_states.shape[1],
-                                _current_ip_hidden_states.shape[2],
-                            )
-
-                            mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
-                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
-                    else:
-                        ip_key = to_k_ip(current_ip_hidden_states)
-                        ip_value = to_v_ip(current_ip_hidden_states)
-
-                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                        # TODO: add support for attn.scale when we move to Torch 2.1
-                        current_ip_hidden_states = F.scaled_dot_product_attention(
-                            ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                        )
-
-                        current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                            batch_size, -1, attn.heads * head_dim
-                        )
-                        current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                ip_attn_output = F.scaled_dot_product_attention(
+                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )

-                        ip_attn_output = scale * current_ip_hidden_states
+                ip_attn_output = ip_attn_output.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                ip_attn_output = scale * ip_attn_output
+                print(ip_attn_output)
+                ip_attn_output = ip_attn_output.to(ip_query.dtype)

             return hidden_states, encoder_hidden_states, ip_attn_output
         else:
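For reference, below is a minimal standalone sketch of the computation the new loop body performs: project the IP-adapter image embeddings through the per-adapter `to_k_ip`/`to_v_ip` layers, split heads, run `F.scaled_dot_product_attention` against the sample query, merge heads, and apply the adapter scale. The tensor sizes and the freestanding `nn.Linear` layers are illustrative assumptions for the sketch, not values taken from this commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (assumed for this sketch only).
batch_size, seq_len, num_image_tokens = 2, 77, 4
heads, head_dim = 8, 64
cross_attention_dim = 1280
inner_dim = heads * head_dim

# Stand-ins for one IP-adapter's key/value projections (`to_k_ip`, `to_v_ip`).
to_k_ip = nn.Linear(cross_attention_dim, inner_dim, bias=False)
to_v_ip = nn.Linear(cross_attention_dim, inner_dim, bias=False)

# `ip_query` reuses the sample's query projection; `image_projection` stands in
# for the projected IP-adapter image embeddings passed into the processor.
ip_query = torch.randn(batch_size, heads, seq_len, head_dim)
image_projection = torch.randn(batch_size, num_image_tokens, cross_attention_dim)
scale = 1.0

# Same steps as the added hunk: project to K/V, split heads, SDPA, merge heads, scale.
ip_key = to_k_ip(image_projection).view(batch_size, -1, heads, head_dim).transpose(1, 2)
ip_value = to_v_ip(image_projection).view(batch_size, -1, heads, head_dim).transpose(1, 2)

ip_attn_output = F.scaled_dot_product_attention(
    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
ip_attn_output = scale * ip_attn_output
print(ip_attn_output.shape)  # torch.Size([2, 77, 512])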