@@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    return xq_out.reshape(*xq.shape).type_as(xq)


-class PhotonAttnProcessor2_0:
+class PRXAttnProcessor2_0:
    r"""
-    Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
+    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
    backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
    """

@@ -91,30 +91,30 @@ class PhotonAttnProcessor2_0:

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-            raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
-        attn: "PhotonAttention",
+        attn: "PRXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
-        Apply Photon attention using PhotonAttention module.
+        Apply PRX attention using PRXAttention module.

        Args:
-            attn: PhotonAttention module containing projection layers
+            attn: PRXAttention module containing projection layers
            hidden_states: Image tokens [B, L_img, D]
            encoder_hidden_states: Text tokens [B, L_txt, D]
            attention_mask: Boolean mask for text tokens [B, L_txt]
            image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
        """

        if encoder_hidden_states is None:
-            raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
+            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

        # Project image tokens to Q, K, V
        img_qkv = attn.img_qkv_proj(hidden_states)
@@ -190,14 +190,14 @@ def __call__(
        return attn_output


-class PhotonAttention(nn.Module, AttentionModuleMixin):
+class PRXAttention(nn.Module, AttentionModuleMixin):
    r"""
-    Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
-    Photon's architecture.
+    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+    PRX's architecture.
    """

-    _default_processor_cls = PhotonAttnProcessor2_0
-    _available_processors = [PhotonAttnProcessor2_0]
+    _default_processor_cls = PRXAttnProcessor2_0
+    _available_processors = [PRXAttnProcessor2_0]

    def __init__(
        self,
@@ -251,7 +251,7 @@ def forward(


# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class PhotonEmbedND(nn.Module):
+class PRXEmbedND(nn.Module):
    r"""
    N-dimensional rotary positional embedding.

@@ -347,7 +347,7 @@ def forward(
        return tuple(out[:3]), tuple(out[3:])


-class PhotonBlock(nn.Module):
+class PRXBlock(nn.Module):
    r"""
    Multimodal transformer block with text–image cross-attention, modulation, and MLP.

@@ -364,7 +364,7 @@ class PhotonBlock(nn.Module):
    Attributes:
        img_pre_norm (`nn.LayerNorm`):
            Pre-normalization applied to image tokens before attention.
-        attention (`PhotonAttention`):
+        attention (`PRXAttention`):
            Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
            image and text tokens.
        post_attention_layernorm (`nn.LayerNorm`):
@@ -400,15 +400,15 @@ def __init__(
        # Pre-attention normalization for image tokens
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

-        # PhotonAttention module with built-in projections and norms
-        self.attention = PhotonAttention(
+        # PRXAttention module with built-in projections and norms
+        self.attention = PRXAttention(
            query_dim=hidden_size,
            heads=num_heads,
            dim_head=self.head_dim,
            bias=False,
            out_bias=False,
            eps=1e-6,
-            processor=PhotonAttnProcessor2_0(),
+            processor=PRXAttnProcessor2_0(),
        )

        # mlp
@@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)


-class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
+class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text to image generation.

@@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
        txt_in (`nn.Linear`):
            Projection layer for text conditioning.
        blocks (`nn.ModuleList`):
-            Stack of transformer blocks (`PhotonBlock`).
+            Stack of transformer blocks (`PRXBlock`).
        final_layer (`LastLayer`):
            Projection layer mapping hidden tokens back to patch outputs.

@@ -661,14 +661,14 @@ def __init__(

        self.hidden_size = hidden_size
        self.num_heads = num_heads
-        self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.blocks = nn.ModuleList(
            [
-                PhotonBlock(
+                PRXBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,
@@ -702,7 +702,7 @@ def forward(
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        r"""
-        Forward pass of the PhotonTransformer2DModel.
+        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
        transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
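For reviewers trying the rename locally, here is a minimal sketch of wiring up the renamed classes, mirroring the `PRXAttention(...)` constructor call in the `PRXBlock.__init__` hunk above. The import path, the concrete sizes, and the `head_dim` formula are assumptions for illustration only and are not part of this diff.

```python
# Sketch only: the module path below is a guess at where the renamed file may
# live in the diffusers tree; adjust it to the actual location.
import torch

from diffusers.models.transformers.transformer_prx import (  # hypothetical path
    PRXAttention,
    PRXAttnProcessor2_0,
)

hidden_size, num_heads = 1024, 16  # example sizes, not taken from this diff
head_dim = hidden_size // num_heads  # assumes head_dim = hidden_size / num_heads

# Mirrors the constructor call inside PRXBlock.__init__ shown above.
attention = PRXAttention(
    query_dim=hidden_size,
    heads=num_heads,
    dim_head=head_dim,
    bias=False,
    out_bias=False,
    eps=1e-6,
    processor=PRXAttnProcessor2_0(),
)

# Dummy inputs with the shapes documented in PRXAttnProcessor2_0.__call__:
# image tokens [B, L_img, D], text tokens [B, L_txt, D], boolean mask [B, L_txt].
hidden_states = torch.randn(2, 256, hidden_size)
encoder_hidden_states = torch.randn(2, 77, hidden_size)
attention_mask = torch.ones(2, 77, dtype=torch.bool)
# The full forward signature of PRXAttention sits outside the hunks shown here,
# so the actual attention call is left out of this sketch.
```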