
Commit c44717b

rename photon into prx
1 parent 8417ff5 commit c44717b

File tree

13 files changed, +65 −66 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -541,12 +541,12 @@
       title: PAG
     - local: api/pipelines/paint_by_example
       title: Paint by Example
-    - local: api/pipelines/prx
-      title: PRX
     - local: api/pipelines/pixart
       title: PixArt-α
     - local: api/pipelines/pixart_sigma
       title: PixArt-Σ
+    - local: api/pipelines/prx
+      title: PRX
     - local: api/pipelines/qwenimage
       title: QwenImage
     - local: api/pipelines/sana

src/diffusers/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -232,9 +232,9 @@
             "MultiControlNetModel",
             "OmniGenTransformer2DModel",
             "ParallelConfig",
-            "PRXTransformer2DModel",
             "PixArtTransformer2DModel",
             "PriorTransformer",
+            "PRXTransformer2DModel",
             "QwenImageControlNetModel",
             "QwenImageMultiControlNetModel",
             "QwenImageTransformer2DModel",
@@ -516,11 +516,11 @@
             "MusicLDMPipeline",
             "OmniGenPipeline",
             "PaintByExamplePipeline",
-            "PRXPipeline",
             "PIAPipeline",
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
+            "PRXPipeline",
             "QwenImageControlNetInpaintPipeline",
             "QwenImageControlNetPipeline",
             "QwenImageEditInpaintPipeline",
@@ -928,9 +928,9 @@
         MultiControlNetModel,
         OmniGenTransformer2DModel,
         ParallelConfig,
-        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageControlNetModel,
         QwenImageMultiControlNetModel,
         QwenImageTransformer2DModel,
@@ -1182,11 +1182,11 @@
         MusicLDMPipeline,
         OmniGenPipeline,
         PaintByExamplePipeline,
-        PRXPipeline,
         PIAPipeline,
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        PRXPipeline,
         QwenImageControlNetInpaintPipeline,
         QwenImageControlNetPipeline,
         QwenImageEditInpaintPipeline,
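
These hunks only re-sort the existing `PRXPipeline` and `PRXTransformer2DModel` export entries to match the new `prx` name (they previously sat where `photon` sorted alphabetically); both names stay importable from the package root. A quick import sanity check, assuming a diffusers build that includes this commit:

```py
# Not part of the diff: just verifies the re-sorted entries still resolve
# from the top-level diffusers namespace.
from diffusers import PRXPipeline, PRXTransformer2DModel

print(PRXPipeline.__name__, PRXTransformer2DModel.__name__)
```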

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -191,9 +191,9 @@
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
-        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageTransformer2DModel,
         SanaTransformer2DModel,
         SD3Transformer2DModel,

src/diffusers/models/transformers/transformer_photon.py renamed to src/diffusers/models/transformers/transformer_prx.py

Lines changed: 23 additions & 23 deletions
@@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return xq_out.reshape(*xq.shape).type_as(xq)
 
 
-class PhotonAttnProcessor2_0:
+class PRXAttnProcessor2_0:
     r"""
-    Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
+    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
     backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
     """
 
@@ -91,30 +91,30 @@ class PhotonAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-            raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
-        attn: "PhotonAttention",
+        attn: "PRXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
-        Apply Photon attention using PhotonAttention module.
+        Apply PRX attention using PRXAttention module.
 
         Args:
-            attn: PhotonAttention module containing projection layers
+            attn: PRXAttention module containing projection layers
             hidden_states: Image tokens [B, L_img, D]
             encoder_hidden_states: Text tokens [B, L_txt, D]
             attention_mask: Boolean mask for text tokens [B, L_txt]
             image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
         """
 
         if encoder_hidden_states is None:
-            raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
+            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
 
         # Project image tokens to Q, K, V
         img_qkv = attn.img_qkv_proj(hidden_states)
@@ -190,14 +190,14 @@ def __call__(
         return attn_output
 
 
-class PhotonAttention(nn.Module, AttentionModuleMixin):
+class PRXAttention(nn.Module, AttentionModuleMixin):
     r"""
-    Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
-    Photon's architecture.
+    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+    PRX's architecture.
     """
 
-    _default_processor_cls = PhotonAttnProcessor2_0
-    _available_processors = [PhotonAttnProcessor2_0]
+    _default_processor_cls = PRXAttnProcessor2_0
+    _available_processors = [PRXAttnProcessor2_0]
 
     def __init__(
         self,
@@ -251,7 +251,7 @@ def forward(
 
 
 # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class PhotonEmbedND(nn.Module):
+class PRXEmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
 
@@ -347,7 +347,7 @@ def forward(
         return tuple(out[:3]), tuple(out[3:])
 
 
-class PhotonBlock(nn.Module):
+class PRXBlock(nn.Module):
     r"""
     Multimodal transformer block with text–image cross-attention, modulation, and MLP.
 
@@ -364,7 +364,7 @@ class PhotonBlock(nn.Module):
     Attributes:
         img_pre_norm (`nn.LayerNorm`):
             Pre-normalization applied to image tokens before attention.
-        attention (`PhotonAttention`):
+        attention (`PRXAttention`):
             Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
             image and text tokens.
         post_attention_layernorm (`nn.LayerNorm`):
@@ -400,15 +400,15 @@ def __init__(
         # Pre-attention normalization for image tokens
         self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
-        # PhotonAttention module with built-in projections and norms
-        self.attention = PhotonAttention(
+        # PRXAttention module with built-in projections and norms
+        self.attention = PRXAttention(
             query_dim=hidden_size,
             heads=num_heads,
             dim_head=self.head_dim,
             bias=False,
             out_bias=False,
             eps=1e-6,
-            processor=PhotonAttnProcessor2_0(),
+            processor=PRXAttnProcessor2_0(),
         )
 
         # mlp
@@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
     return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
 
 
-class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
+class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     Transformer-based 2D model for text to image generation.
 
@@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
         txt_in (`nn.Linear`):
             Projection layer for text conditioning.
         blocks (`nn.ModuleList`):
-            Stack of transformer blocks (`PhotonBlock`).
+            Stack of transformer blocks (`PRXBlock`).
         final_layer (`LastLayer`):
             Projection layer mapping hidden tokens back to patch outputs.
 
@@ -661,14 +661,14 @@ def __init__(
 
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
         self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
 
         self.blocks = nn.ModuleList(
             [
-                PhotonBlock(
+                PRXBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
@@ -702,7 +702,7 @@ def forward(
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
         r"""
-        Forward pass of the PhotonTransformer2DModel.
+        Forward pass of the PRXTransformer2DModel.
 
         The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
         transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
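
Beyond the class renames above, the module's behavior is unchanged; only the import path moves from `transformer_photon` to `transformer_prx`. A minimal sketch of loading the renamed transformer from its new module path, assuming the `Photoroom/prx-512-t2i-sft` checkpoint referenced in the pipeline example below stores the model under the usual `transformer` subfolder:

```py
import torch

# New module path introduced by this rename; the subfolder layout is an
# assumption based on the standard diffusers checkpoint structure.
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel

transformer = PRXTransformer2DModel.from_pretrained(
    "Photoroom/prx-512-t2i-sft", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(transformer.config)
```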

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -718,9 +718,9 @@
             StableDiffusionXLPAGPipeline,
         )
         from .paint_by_example import PaintByExamplePipeline
-        from .prx import PRXPipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+        from .prx import PRXPipeline
         from .qwenimage import (
             QwenImageControlNetInpaintPipeline,
             QwenImageControlNetPipeline,
File renamed without changes.
File renamed without changes.

src/diffusers/pipelines/photon/pipeline_photon.py renamed to src/diffusers/pipelines/prx/pipeline_prx.py

Lines changed: 17 additions & 18 deletions
@@ -30,9 +30,9 @@
 from diffusers.image_processor import PixArtImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderDC, AutoencoderKL
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
-from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     logging,
@@ -73,7 +73,7 @@
 
 
 class TextPreprocessor:
-    """Text preprocessing utility for PhotonPipeline."""
+    """Text preprocessing utility for PRXPipeline."""
 
     def __init__(self):
         """Initialize text preprocessor."""
@@ -203,34 +203,34 @@ def clean_text(self, text: str) -> str:
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import PhotonPipeline
+        >>> from diffusers import PRXPipeline
 
         >>> # Load pipeline with from_pretrained
-        >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
+        >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
         >>> pipe.to("cuda")
 
         >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
         >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
-        >>> image.save("photon_output.png")
+        >>> image.save("prx_output.png")
         ```
 """
 
 
-class PhotonPipeline(
+class PRXPipeline(
     DiffusionPipeline,
     LoraLoaderMixin,
     FromSingleFileMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
-    Pipeline for text-to-image generation using Photon Transformer.
+    Pipeline for text-to-image generation using PRX Transformer.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
     Args:
-        transformer ([`PhotonTransformer2DModel`]):
-            The Photon transformer model to denoise the encoded image latents.
+        transformer ([`PRXTransformer2DModel`]):
+            The PRX transformer model to denoise the encoded image latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         text_encoder ([`T5GemmaEncoder`]):
@@ -248,7 +248,7 @@ class PhotonPipeline(
 
     def __init__(
         self,
-        transformer: PhotonTransformer2DModel,
+        transformer: PRXTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         text_encoder: T5GemmaEncoder,
         tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
@@ -257,9 +257,9 @@ def __init__(
     ):
         super().__init__()
 
-        if PhotonTransformer2DModel is None:
+        if PRXTransformer2DModel is None:
             raise ImportError(
-                "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
+                "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
             )
 
         self.text_preprocessor = TextPreprocessor()
@@ -567,7 +567,7 @@ def __call__(
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
             use_resolution_binning (`bool`, *optional*, defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
@@ -585,9 +585,8 @@ def __call__(
         Examples:

         Returns:
-            [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
 
         # 0. Set height and width
@@ -765,4 +764,4 @@ def __call__(
         if not return_dict:
             return (image,)
 
-        return PhotonPipelineOutput(images=image)
+        return PRXPipelineOutput(images=image)
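
The docstring changes above also restate the pipeline's two return modes: a `PRXPipelineOutput` when `return_dict=True` (the default) and a plain tuple whose first element is the list of generated images otherwise. A short sketch of both, reusing the prompt and settings from the docstring example and assuming the same checkpoint:

```py
from diffusers import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft").to("cuda")
prompt = "A digital painting of a rusty, vintage tram on a sandy beach"

# Default: returns a PRXPipelineOutput with an `images` attribute.
out = pipe(prompt, num_inference_steps=28, guidance_scale=5.0)
out.images[0].save("prx_output.png")

# return_dict=False: returns a tuple; the first element is the list of images.
(images,) = pipe(prompt, num_inference_steps=28, guidance_scale=5.0, return_dict=False)
images[0].save("prx_output_tuple.png")
```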
