Skip to content

Commit 0afbe6c

Browse files
committed
Updated docs and types for Lumina pipelines
1 parent 03a8fcf commit 0afbe6c

File tree

3 files changed: +19 additions, −25 deletions

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import PreTrainedModel, PreTrainedTokenizerBase
23+
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...models import AutoencoderKL
@@ -143,13 +143,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
143143
Args:
144144
vae ([`AutoencoderKL`]):
145145
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
146-
text_encoder ([`PreTrainedModel`]):
147-
Frozen text-encoder. Lumina-T2I uses
148-
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
149-
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
150-
tokenizer (`AutoTokenizer`):
151-
Tokenizer of class
152-
[AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
146+
text_encoder ([`GemmaPreTrainedModel`]):
147+
Frozen Gemma text-encoder.
148+
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
149+
Gemma tokenizer.
153150
transformer ([`Transformer2DModel`]):
154151
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
155152
scheduler ([`SchedulerMixin`]):
@@ -180,8 +177,8 @@ def __init__(
180177
transformer: LuminaNextDiT2DModel,
181178
scheduler: FlowMatchEulerDiscreteScheduler,
182179
vae: AutoencoderKL,
183-
text_encoder: PreTrainedModel,
184-
tokenizer: PreTrainedTokenizerBase,
180+
text_encoder: GemmaPreTrainedModel,
181+
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
185182
):
186183
super().__init__()
187184

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import torch
20-
from transformers import PreTrainedModel, PreTrainedTokenizerBase
20+
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
2121

2222
from ...image_processor import VaeImageProcessor
2323
from ...models import AutoencoderKL
@@ -150,13 +150,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
150150
Args:
151151
vae ([`AutoencoderKL`]):
152152
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
153-
text_encoder ([`PreTrainedModel`]):
154-
Frozen text-encoder. Lumina-T2I uses
155-
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
156-
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
157-
tokenizer (`PreTrainedTokenizerBase`):
158-
Tokenizer of class
159-
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
153+
text_encoder ([`Gemma2PreTrainedModel`]):
154+
Frozen Gemma2 text-encoder.
155+
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
156+
Gemma tokenizer.
160157
transformer ([`Transformer2DModel`]):
161158
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
162159
scheduler ([`SchedulerMixin`]):
@@ -172,8 +169,8 @@ def __init__(
172169
transformer: Lumina2Transformer2DModel,
173170
scheduler: FlowMatchEulerDiscreteScheduler,
174171
vae: AutoencoderKL,
175-
text_encoder: PreTrainedModel,
176-
tokenizer: PreTrainedTokenizerBase,
172+
text_encoder: Gemma2PreTrainedModel,
173+
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
177174
):
178175
super().__init__()
179176

tests/pipelines/lumina2/test_pipeline_lumina2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import torch
5-
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
5+
from transformers import AutoTokenizer, Gemma2Config, Gemma2ForCausalLM
66

77
from diffusers import (
88
AutoencoderKL,
@@ -81,21 +81,21 @@ def get_dummy_components(self):
8181
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
8282

8383
torch.manual_seed(0)
84-
config = GemmaConfig(
84+
config = Gemma2Config(
8585
head_dim=2,
8686
hidden_size=8,
8787
intermediate_size=37,
8888
num_attention_heads=4,
8989
num_hidden_layers=2,
9090
num_key_value_heads=4,
9191
)
92-
text_encoder = GemmaForCausalLM(config)
92+
text_encoder = Gemma2ForCausalLM(config)
9393

9494
components = {
95-
"transformer": transformer.eval(),
95+
"transformer": transformer,
9696
"vae": vae.eval(),
9797
"scheduler": scheduler,
98-
"text_encoder": text_encoder.eval(),
98+
"text_encoder": text_encoder,
9999
"tokenizer": tokenizer,
100100
}
101101
return components

0 commit comments

Comments (0)