Commit 7d6c7ed

Merge pull request #2 from AshishKumar4/feat/video-diffusion
feat: fixed encoding generation flows
2 parents b46fbac + ce52314

4 files changed (+58 −35)

flaxdiff/data/datasets.py

Lines changed: 18 additions & 18 deletions

@@ -39,11 +39,11 @@ def get_dataset_grain(
     augmenter = dataset["augmenter"](image_scale, method)
 
     local_batch_size = batch_size // jax.process_count()
-    model, tokenizer = defaultTextEncodeModel()
+    # model, tokenizer = defaultTextEncodeModel()
 
-    null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
-    null_labels = np.array(null_labels[0], dtype=np.float16)
-    null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
+    # null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
+    # null_labels = np.array(null_labels[0], dtype=np.float16)
+    # null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
 
     sampler = pygrain.IndexSampler(
         num_records=len(data_source) if count is None else count,
@@ -80,13 +80,13 @@ def get_trainset():
         "train_len": len(data_source),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
-        "null_labels": null_labels,
-        "null_labels_full": null_labels_full,
-        "model": model,
-        "tokenizer": tokenizer,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
     }
 
-def generate_collate_fn(tokenizer):
+def generate_collate_fn():
     auto_tokenize = AutoTextTokenizer(tensor_type="np")
     def default_collate(batch):
         try:
@@ -121,11 +121,11 @@ def get_dataset_online(
 ):
     local_batch_size = batch_size // jax.process_count()
 
-    model, tokenizer = defaultTextEncodeModel()
+    # model, tokenizer = defaultTextEncodeModel()
 
-    null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
-    null_labels = np.array(null_labels[0], dtype=np.float16)
-    null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
+    # null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
+    # null_labels = np.array(null_labels[0], dtype=np.float16)
+    # null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
 
     sources = onlineDatasetMap[data_name]["source"]
     dataloader = OnlineStreamingDataLoader(
@@ -137,7 +137,7 @@ def get_dataset_online(
         global_process_count=jax.process_count(),
         global_process_index=jax.process_index(),
         prefetch=worker_buffer_size,
-        collate_fn=generate_collate_fn(tokenizer),
+        collate_fn=generate_collate_fn(),
         default_split="train",
     )
 
@@ -173,8 +173,8 @@ def __next__(self):
         "train_len": len(dataloader) * jax.process_count(),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
-        "null_labels": null_labels,
-        "null_labels_full": null_labels_full,
-        "model": model,
-        "tokenizer": tokenizer,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
     }
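
Note: after this change the datasets module no longer builds the text-encoding model at all. generate_collate_fn constructs its own AutoTextTokenizer and only tokenizes captions; the actual encoding now happens in the trainer. A minimal sketch of the resulting collate flow follows (the per-sample layout and the AutoTextTokenizer behaviour beyond tensor_type="np" are assumptions, not shown in this diff):

    import numpy as np

    def generate_collate_fn():
        # AutoTextTokenizer builds its own tokenizer internally, so no tokenizer
        # has to be threaded through the data pipeline any more.
        auto_tokenize = AutoTextTokenizer(tensor_type="np")

        def default_collate(batch):
            # Assumed sample layout: {"image": ndarray, "text": str}
            captions = [sample["text"] for sample in batch]
            tokens = auto_tokenize(captions)
            return {
                "image": np.stack([sample["image"] for sample in batch]),
                "input_ids": tokens["input_ids"],
                "attention_mask": tokens["attention_mask"],
            }

        return default_collate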

flaxdiff/trainer/autoencoder_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from flaxdiff.utils import RandomMarkovState
 
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
+from .diffusion_trainer import TrainState
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 
 class AutoEncoderTrainer(SimpleTrainer):

flaxdiff/trainer/diffusion_trainer.py

Lines changed: 11 additions & 6 deletions

@@ -19,6 +19,7 @@
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib
+from flaxdiff.utils import TextEncoder, ConditioningEncoder
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -49,6 +50,7 @@ def __init__(self,
                  name: str = "Diffusion",
                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                  autoencoder: AutoEncoder = None,
+                 encoder: ConditioningEncoder = None,
                  **kwargs
                  ):
         super().__init__(
@@ -64,6 +66,7 @@ def __init__(self,
         self.unconditional_prob = unconditional_prob
 
         self.autoencoder = autoencoder
+        self.encoder = encoder
 
     def generate_states(
         self,
@@ -106,7 +109,7 @@ def generate_states(
 
         return state, best_state
 
-    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+    def _define_train_step(self, batch_size):
         noise_schedule: NoiseScheduler = self.noise_schedule
         model = self.model
         model_output_transform = self.model_output_transform
@@ -115,6 +118,11 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
 
         # Determine the number of unconditional samples
         num_unconditional = int(batch_size * unconditional_prob)
+
+        _, null_labels_full = self.encoder([""])
+        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
+
+        conditioning_encoder = self.encoder.model
 
         nS, nC = null_labels_seq.shape
         null_labels_seq = jnp.broadcast_to(
@@ -146,7 +154,7 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
                 local_rng_state, rngs = local_rng_state.get_random_key()
                 images = autoencoder.encode(images, rngs)
 
-            output = text_embedder(
+            output = conditioning_encoder(
                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
             label_seq = output.last_hidden_state
 
@@ -231,8 +239,5 @@ def compute_metrics(state: TrainState, expected, pred):
         return compute_metrics
 
     def fit(self, data, steps_per_epoch, epochs):
-        null_labels_full = data['null_labels_full']
         local_batch_size = data['local_batch_size']
-        text_embedder = data['model']
-        super().fit(data, steps_per_epoch, epochs, {
-            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+        super().fit(data, steps_per_epoch, epochs, {"batch_size": local_batch_size})
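
Note: the unconditional ("null") label embedding is now derived lazily inside _define_train_step from the trainer's own encoder, instead of being produced by the data pipeline and threaded through fit. A sketch of that step in isolation (the helper name make_null_labels is hypothetical; the tuple return of encoder([""]) and the broadcast target shape are inferred from the diff, which truncates the jnp.broadcast_to call):

    import jax.numpy as jnp

    def make_null_labels(encoder, batch_size):
        # encoder([""]) is assumed to return (pooled, full_sequence) embeddings,
        # matching the unpacking `_, null_labels_full = self.encoder([""])`.
        _, null_labels_full = encoder([""])
        # A single (seq_len, channels) embedding for the empty prompt.
        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
        nS, nC = null_labels_seq.shape
        # Replicate the one null embedding across the batch so it can replace
        # the labels of the first num_unconditional samples in each batch.
        return jnp.broadcast_to(null_labels_seq, (batch_size, nS, nC))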

flaxdiff/utils.py

Lines changed: 28 additions & 10 deletions

@@ -7,6 +7,7 @@
 from functools import partial
 import numpy as np
 from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
 
 class MarkovState(struct.PyTreeNode):
     pass
@@ -115,21 +116,38 @@ def _normalize(
         mul *= scale
     y = mul * x
     return jnp.asarray(y, dtype)
-
-
+
 @dataclass
-class TextEncoder:
+class ConditioningEncoder(ABC):
     model: nn.Module
     tokenizer: Callable
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state
 
-    def __call__(self, prompts):
-        # inputs = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="np")
-        inputs = self.tokenizer(prompts, padding="max_length",
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
                                 max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
-        outputs = self.model(input_ids=inputs['input_ids'],
-                             attention_mask=inputs['attention_mask'])
-        # outputs = infer(inputs['input_ids'], inputs['attention_mask'])
-
+        return tokens
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
         last_hidden_state = outputs.last_hidden_state
         pooler_output = outputs.pooler_output  # pooled (EOS token) states
         embed_pooled = pooler_output  # .astype(jnp.float16)
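
Note: ConditioningEncoder splits conditioning into two reusable stages, tokenize (cheap, suitable for data-loader workers) and encode_from_tokens (runs the model), which is the split the diffusion trainer exploits above. A hypothetical usage sketch, assuming a HuggingFace CLIP text model as the backing pair (this diff shows neither which model flaxdiff pairs with TextEncoder nor the full return value of TextEncoder.encode_from_tokens, which is truncated here):

    from transformers import AutoTokenizer, FlaxCLIPTextModel

    # Assumed model/tokenizer pair; any encoder exposing last_hidden_state
    # and pooler_output would satisfy the interface shown in the diff.
    model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    encoder = TextEncoder(model=model, tokenizer=tokenizer)

    # One-shot: tokenize, then encode.
    embeddings = encoder(["a photograph of an astronaut riding a horse"])

    # Or run the two stages separately, e.g. tokenizing in the collate
    # function and encoding inside the training step:
    tokens = encoder.tokenize(["a cat"])
    seq = encoder.encode_from_tokens(tokens)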
