@@ -3510,7 +3510,7 @@ def __call__(
         # if attn.group_norm is not None:
         #     hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-
+        # batch_size = hidden_states.shape[0]
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
@@ -3528,10 +3528,10 @@ def __call__(
 
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+        """
         assert attn.norm_q is None
         assert attn.norm_k is None
-        """
+
         # if attn.norm_q is not None:
         #     query = attn.norm_q(query)
         # if attn.norm_k is not None:
@@ -3557,26 +3557,27 @@ def __call__(
         #     # logger.warning(
         #     #     "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
         #     # )
-        #     hidden_states = self.scaled_dot_product_attention_compiled(
-        #         query, key, value
-        #     )
+        #     hidden_states = self.scaled_dot_product_attention(
+        #         query, key, value
+        #     )
 
         #*hidden_states = JaxFun.apply(query, key, value)
+        import pdb; pdb.set_trace()
         hidden_states = JaxFun.apply(hidden_states, encoder_hidden_states, attn.to_q.weight, attn.to_k.weight, attn.to_v.weight, attn.heads)
         hidden_states = hidden_states.to(input_dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        # hidden_states = attn.to_out[1](hidden_states)
 
         # if input_ndim == 4:
         #     hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
+        # if attn.residual_connection:
+        #     hidden_states = hidden_states + residual
 
-        hidden_states = hidden_states / attn.rescale_output_factor
+        # hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
 
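The commented-out branch above hints at a length-gated dispatch: the Pallas TPU flash-attention kernel is only usable once the QKV sequence length reaches 4096, with plain scaled dot-product attention as the fallback. A minimal sketch of that dispatch, assuming hypothetical names `attention_with_fallback`, `pallas_kernel`, and `MIN_PALLAS_SEQ_LEN` (none of these appear in the commit):

```python
import torch
import torch.nn.functional as F

# Threshold quoted in the commented-out warning above (assumption: it gates
# the Pallas kernel, which needs sequence length >= 4096).
MIN_PALLAS_SEQ_LEN = 4096

def attention_with_fallback(query, key, value, pallas_kernel=None):
    # query/key/value: (batch, heads, seq_len, head_dim)
    if pallas_kernel is not None and query.shape[-2] >= MIN_PALLAS_SEQ_LEN:
        # Long sequences: hand off to the TPU flash-attention kernel.
        return pallas_kernel(query, key, value)
    # Short sequences: eager scaled dot-product attention.
    return F.scaled_dot_product_attention(query, key, value)
```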
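The commit's core change routes attention through `JaxFun.apply(hidden_states, encoder_hidden_states, attn.to_q.weight, attn.to_k.weight, attn.to_v.weight, attn.heads)` (with a stray `pdb.set_trace()` breakpoint left in), but `JaxFun` itself is not defined in this hunk. Below is a minimal, forward-only sketch of what such a torch-to-JAX bridge might look like, assuming it applies the q/k/v projections and runs multi-head cross-attention in JAX, moving tensors across via DLPack. Everything except the call signature is an assumption.

```python
import jax
import jax.numpy as jnp
import torch
from torch.utils import dlpack as torch_dlpack

def _cross_attention(hidden, encoder, wq, wk, wv, heads):
    # torch Linear weights are stored as (out_features, in_features).
    q = hidden @ wq.T
    k = encoder @ wk.T
    v = encoder @ wv.T
    b, n, d = q.shape
    hd = d // heads
    # Split into (batch, heads, seq, head_dim).
    q = q.reshape(b, n, heads, hd).transpose(0, 2, 1, 3)
    k = k.reshape(b, k.shape[1], heads, hd).transpose(0, 2, 1, 3)
    v = v.reshape(b, v.shape[1], heads, hd).transpose(0, 2, 1, 3)
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(hd)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
    # Merge heads back to (batch, seq, dim).
    return out.transpose(0, 2, 1, 3).reshape(b, n, d)

class JaxFun(torch.autograd.Function):
    """Hypothetical forward-only bridge: torch tensors -> JAX attention -> torch."""

    @staticmethod
    def forward(ctx, hidden, encoder, wq, wk, wv, heads):
        to_jax = lambda t: jnp.from_dlpack(torch_dlpack.to_dlpack(t.detach().contiguous()))
        out = _cross_attention(
            to_jax(hidden), to_jax(encoder), to_jax(wq), to_jax(wk), to_jax(wv), heads
        )
        # DLPack round-trip avoids a host copy when both sides share a device.
        return torch.from_dlpack(jax.dlpack.to_dlpack(out))

    @staticmethod
    def backward(ctx, grad_out):
        # Inference-only sketch; no gradient flows back to torch.
        raise NotImplementedError
```

This would also explain the `hidden_states.to(input_dtype)` cast right after the call: the round-trip through JAX may not preserve the model's compute dtype.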