@@ -2621,3 +2621,114 @@ def forward(self, image_embeds: List[torch.Tensor]):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+    def __init__(self, text_emb_dim, hidden_size):
+        super().__init__()
+        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, pooled_embed):
+        return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, timesteps, wdtype):
+        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+        t_emb = self.timestep_embedder(t_emb)
+        return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.zeros_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, adaln_input):
+        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
+        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = self.linear(x)
+        return x
+
+
+class HiDreamImagePatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size=2,
+        in_channels=4,
+        out_channels=1024,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.out_channels = out_channels
+        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, latent):
+        latent = self.proj(latent)
+        return latent
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    assert dim % 2 == 0, "The dimension must be even."
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+
+    batch_size, seq_length = pos.shape
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+
+    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+    return out.float()
+
+
+class HiDreamImageEmbedND(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(2)
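Note that HiDreamImagePatchEmbed projects with an nn.Linear over already-flattened patches rather than a strided Conv2d, so its forward expects latents of shape (batch, num_patches, in_channels * patch_size**2). Below is a minimal patchify sketch for reference only; the exact patch-element ordering used by the surrounding model code is not shown in this hunk and may differ.

# Illustrative only -- not part of this commit. One way to turn a
# (B, C, H, W) latent into (B, num_patches, C * p * p) tokens before
# feeding HiDreamImagePatchEmbed's Linear projection. The flattening order
# here is an assumption, not taken from this diff.
import torch

def patchify(latent: torch.Tensor, p: int = 2) -> torch.Tensor:
    B, C, H, W = latent.shape
    x = latent.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)                       # (B, H/p, W/p, C, p, p)
    return x.reshape(B, (H // p) * (W // p), C * p * p)   # (B, num_patches, C*p*p)

# Usage sketch:
# patch_embed = HiDreamImagePatchEmbed(patch_size=2, in_channels=4, out_channels=1024)
# tokens = patch_embed(patchify(latent, p=2))   # -> (B, num_patches, 1024)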
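For context on the rotary pieces: rope() returns per-position 2x2 rotation matrices of shape (batch, seq_len, dim // 2, 2, 2), and HiDreamImageEmbedND concatenates them across axes and unsqueezes a broadcast dimension for the attention heads. A minimal sketch of how such flux-style rotation tensors are commonly applied to query/key states follows; the (batch, seq_len, heads, head_dim) layout and the helper name apply_rope are assumptions for illustration, not part of this commit.

# Illustrative only -- not part of this commit. Standard flux-style
# application of rope()/HiDreamImageEmbedND output to query/key tensors.
# Assumes xq, xk have shape (batch, seq_len, heads, head_dim) and
# freqs_cis has shape (batch, seq_len, 1, head_dim // 2, 2, 2).
import torch

def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Group the head dimension into (head_dim // 2) even/odd pairs so each
    # pair can be rotated by its matching 2x2 matrix in freqs_cis.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # freqs_cis[..., 0] holds (cos, sin), freqs_cis[..., 1] holds (-sin, cos);
    # the size-1 heads axis broadcasts across all attention heads.
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)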