@@ -2621,3 +2621,59 @@ def forward(self, image_embeds: List[torch.Tensor]):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+    # Projects the pooled text-encoder embedding to the transformer hidden size,
+    # reusing TimestepEmbedding as a generic linear -> SiLU -> linear MLP.
+    def __init__(self, text_emb_dim, hidden_size):
+        super().__init__()
+        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, pooled_embed):
+        return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+    # Sinusoidal timestep features followed by an MLP projection to the hidden size.
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, timesteps, wdtype):
+        # Cast the float32 sinusoidal features to the model's working dtype before the MLP.
+        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+        t_emb = self.timestep_embedder(t_emb)
+        return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+    # Final DiT-style output head: adaLN-modulated LayerNorm, then a linear
+    # projection back to patch_size**2 * out_channels per token.
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        # Zero init: the head emits zeros at the start of training, a common
+        # choice for the last layer of diffusion transformers.
+        if isinstance(m, nn.Linear):
+            nn.init.zeros_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, adaln_input):
+        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
+        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = self.linear(x)
+        return x
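
A minimal sketch of how these three modules could be wired together, assuming the imports used above (torch, torch.nn as nn, and TimestepEmbedding/Timesteps from diffusers.models.embeddings). The toy sizes and the choice to sum the timestep and pooled embeddings into a single adaLN conditioning vector follow the usual DiT recipe and are illustrative, not this model's actual configuration:

import torch

hidden_size, text_emb_dim, patch_size, out_channels = 256, 768, 2, 16

pooled_embed = HiDreamImagePooledEmbed(text_emb_dim, hidden_size)
timestep_embed = HiDreamImageTimestepEmbed(hidden_size)
out_embed = HiDreamImageOutEmbed(hidden_size, patch_size, out_channels)

batch, seq_len = 2, 64
pooled = torch.randn(batch, text_emb_dim)          # pooled text-encoder output
timesteps = torch.randint(0, 1000, (batch,))       # diffusion timesteps
hidden_states = torch.randn(batch, seq_len, hidden_size)

# Conditioning vector for adaLN: timestep embedding plus pooled text embedding
# (summing the two is an assumption, borrowed from the standard DiT setup).
adaln_input = timestep_embed(timesteps, hidden_states.dtype) + pooled_embed(pooled)
out = out_embed(hidden_states, adaln_input)
print(out.shape)  # torch.Size([2, 64, 64]) == (batch, seq_len, patch_size**2 * out_channels)

Because HiDreamImageOutEmbed is zero-initialized, out is all zeros right after construction; training moves the head away from that zero-output start.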