@@ -307,8 +307,14 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            from ..normalization import RMSNorm
+
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -319,8 +325,14 @@ def __init__(
         self.to_out.append(torch.nn.Dropout(dropout))
 
         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            if DIFFUSERS_ENABLE_HUB_KERNELS:
+                from ..normalization import RMSNorm
+
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
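
For reference, below is a minimal standalone sketch of the norm selection introduced in the two hunks above. It assumes DIFFUSERS_ENABLE_HUB_KERNELS is a boolean mirroring an environment variable of the same name, and that the relative import resolves to diffusers.models.normalization.RMSNorm; both are assumptions about the surrounding module, not guaranteed by this diff.

import os

import torch

# Assumption: the flag mirrors an environment variable of the same name.
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "0").lower() in ("1", "true", "yes")

dim_head, eps, elementwise_affine = 128, 1e-6, True

if DIFFUSERS_ENABLE_HUB_KERNELS:
    # diffusers' own RMSNorm module, presumably used here so it can be swapped
    # for a Hub-provided fused kernel when hub kernels are enabled.
    from diffusers.models.normalization import RMSNorm

    norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
else:
    # Stock PyTorch implementation (torch.nn.RMSNorm requires torch>=2.4).
    norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

x = torch.randn(2, 24, 512, dim_head)  # (batch, heads, seq_len, dim_head)
print(norm_q(x).shape)  # normalization is applied over the last dimension
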
@@ -357,10 +369,11 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.act_mlp = nn.GELU(approximate="tanh")
-        else:
-            self.act_mlp = gelu_tanh_kernel()
+        self.act_mlp = nn.GELU(approximate="tanh")
+        # if not DIFFUSERS_ENABLE_HUB_KERNELS:
+        #     self.act_mlp = nn.GELU(approximate="tanh")
+        # else:
+        #     self.act_mlp = gelu_tanh_kernel()
 
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
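
The last hunk pins act_mlp to PyTorch's built-in tanh-approximate GELU and leaves the gelu_tanh_kernel() path commented out. As a quick sanity check of the stock activation (example shapes are assumed, not from this diff):

import torch
import torch.nn as nn

x = torch.randn(4, 1024)

act_exact = nn.GELU()                   # erf-based ("exact") GELU
act_tanh = nn.GELU(approximate="tanh")  # tanh approximation, as used for act_mlp

# The tanh approximation tracks the exact GELU closely; print the worst-case gap.
print((act_exact(x) - act_tanh(x)).abs().max())
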