
Commit b7abbdb

feat: general changes
1 parent: 0d26ac3

File tree

10 files changed: +879 -964 lines


README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 # ![](images/logo.jpeg "FlaxDiff")
 
+**This project is being used for the UMD Course project MSML 605: MLOps**
+
 **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
 
 ## A Versatile and simple Diffusion Library
```

flaxdiff/data/dataloaders.py

Lines changed: 31 additions & 2 deletions
```diff
@@ -291,14 +291,22 @@ def get_dataset_grain(
 
     local_batch_size = batch_size // jax.process_count()
 
-    sampler = pygrain.IndexSampler(
+    train_sampler = pygrain.IndexSampler(
         num_records=len(data_source) if count is None else count,
         shuffle=True,
         seed=seed,
         num_epochs=num_epochs,
         shard_options=pygrain.ShardByJaxProcess(),
     )
 
+    # val_sampler = pygrain.IndexSampler(
+    #     num_records=len(data_source) if count is None else count,
+    #     shuffle=False,
+    #     seed=seed,
+    #     num_epochs=num_epochs,
+    #     shard_options=pygrain.ShardByJaxProcess(),
+    # )
+
     def get_trainset():
         transformations = [
             augmenter(),
@@ -307,7 +315,7 @@ def get_trainset():
 
         loader = pygrain.DataLoader(
             data_source=data_source,
-            sampler=sampler,
+            sampler=train_sampler,
             operations=transformations,
             worker_count=worker_count,
             read_options=pygrain.ReadOptions(
@@ -316,10 +324,31 @@ def get_trainset():
             worker_buffer_size=worker_buffer_size,
         )
         return loader
+
+    # def get_valset():
+    #     transformations = [
+    #         augmenter(),
+    #         pygrain.Batch(local_batch_size, drop_remainder=True),
+    #     ]
+
+    #     loader = pygrain.DataLoader(
+    #         data_source=data_source,
+    #         sampler=val_sampler,
+    #         operations=transformations,
+    #         worker_count=worker_count,
+    #         read_options=pygrain.ReadOptions(
+    #             read_thread_count, read_buffer_size
+    #         ),
+    #         worker_buffer_size=worker_buffer_size,
+    #     )
+    #     return loader
+    get_valset = get_trainset  # For now, use the same function for validation
 
     return {
         "train": get_trainset,
         "train_len": len(data_source),
+        "val": get_valset,
+        "val_len": len(data_source),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
     }
```
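Note that `get_valset` is currently just an alias for `get_trainset`, so the new `"val"` entry replays shuffled training data rather than a held-out split. A minimal sketch of consuming the returned dict, using only the keys visible in this diff (the `data` value is assumed to come from a prior `get_dataset_grain` call with the caller's own arguments):

```python
# Sketch: consuming the dict returned by get_dataset_grain.
def epoch_iterators(data: dict):
    # Each split entry is a factory that builds a pygrain.DataLoader.
    train_iter = iter(data["train"]())
    val_iter = iter(data["val"]())  # currently the same loader factory as "train"
    # Derive steps per epoch from the dataset length and the global batch size.
    steps_per_epoch = data["train_len"] // data["global_batch_size"]
    return train_iter, val_iter, steps_per_epoch
```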

flaxdiff/data/dataset_map.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -21,7 +21,7 @@
         "augmenter": gcs_augmenters,
     },
     "laiona_coco": {
-        "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+        "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
     },
     "aesthetic_coyo": {
```
flaxdiff/data/sources/images.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -167,6 +167,16 @@ def map(self, element) -> Dict[str, jnp.array]:
 
         return TFDSTransform
 
+"""
+Batch structure:
+{
+    "image": image_batch,
+    "text": {
+        "input_ids": input_ids_batch,
+        "attention_mask": attention_mask_batch,
+    }
+}
+"""
 
 # ----------------------------------------------------------------------------------
 # GCS Image Source
@@ -248,6 +258,13 @@ def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable:
             A callable that returns a pygrain.MapTransform.
         """
         labelizer = self.labelizer
+        if method is None:
+            if image_scale > 256:
+                method = cv2.INTER_CUBIC
+            else:
+                method = cv2.INTER_AREA
+
+        print(f"Using method: {method}")
 
         class GCSTransform(pygrain.MapTransform):
             def __init__(self, *args, **kwargs):
```
flaxdiff/trainer/general_diffusion_trainer.py

Lines changed: 74 additions & 23 deletions
```diff
@@ -18,13 +18,13 @@
 from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
 from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
 
-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib
 
 # Reuse the TrainState from the DiffusionTrainer
-from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
+from .diffusion_trainer import TrainState, DiffusionTrainer
 import shutil
 
 def generate_modelname(
@@ -103,6 +103,15 @@ def generate_modelname(
     # model_name = f"{model_name}-{config_hash}"
     return model_name
 
+@dataclass
+class EvaluationMetric:
+    """
+    Evaluation metrics for the diffusion model.
+    The function is given generated samples batch [B, H, W, C] and the original batch.
+    """
+    function: Callable
+    name: str
+
 class GeneralDiffusionTrainer(DiffusionTrainer):
     """
     General trainer for diffusion models supporting both images and videos.
@@ -126,6 +135,7 @@ def __init__(self,
                  native_resolution: int = None,
                  frames_per_sample: int = None,
                  wandb_config: Dict[str, Any] = None,
+                 eval_metrics: List[EvaluationMetric] = None,
                  **kwargs
                  ):
         """
@@ -150,6 +160,7 @@
             autoencoder=autoencoder,
         )
         self.input_config = input_config
+        self.eval_metrics = eval_metrics
 
         if wandb_config is not None:
             # If input_config is not in wandb_config, add it
@@ -363,7 +374,6 @@ def _define_validation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler
         def generate_samples(
             val_state: TrainState,
             batch,
-            sampler: DiffusionSampler,
             diffusion_steps: int,
         ):
             # Process all conditional inputs
@@ -385,7 +395,7 @@
                 model_conditioning_inputs=tuple(model_conditioning_inputs),
             )
 
-        return sampler, generate_samples
+        return generate_samples
 
     def _get_image_size(self):
         """Helper to determine image size from available information."""
@@ -415,32 +425,73 @@ def validation_loop(
         Run validation and log samples for both image and video diffusion.
         """
 
-        sampler, generate_samples = val_step_fn
-        val_ds = iter(val_ds()) if val_ds else None
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        generate_samples = val_step_fn
 
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
         try:
-            # Generate samples
-            samples = generate_samples(
-                val_state,
-                next(val_ds),
-                sampler,
-                diffusion_steps,
-            )
-
-            # Log samples to wandb
-            if getattr(self, 'wandb', None) is not None and self.wandb:
-                import numpy as np
+            metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                    if self.distributed_training and global_device_count > 1:
+                        batch = convert_to_global_tree(self.mesh, batch)
+                # Generate samples
+                samples = generate_samples(
+                    val_state,
+                    batch,
+                    diffusion_steps,
+                )
 
-                # Process samples differently based on dimensionality
-                if len(samples.shape) == 5:  # [B,T,H,W,C] - Video data
-                    self._log_video_samples(samples, current_step)
-                else:  # [B,H,W,C] - Image data
-                    self._log_image_samples(samples, current_step)
+                if self.eval_metrics is not None:
+                    for metric in self.eval_metrics:
+                        try:
+                            # Evaluate metrics
+                            metric_val = metric.function(samples, batch)
+                            metrics[metric.name].append(metric_val)
+                        except Exception as e:
+                            print("Error in evaluation metrics:", e)
+                            import traceback
+                            traceback.print_exc()
+                            pass
 
+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                    # Log samples to wandb
+                    if getattr(self, 'wandb', None) is not None and self.wandb:
+                        import numpy as np
+
+                        # Process samples differently based on dimensionality
+                        if len(samples.shape) == 5:  # [B,T,H,W,C] - Video data
+                            self._log_video_samples(samples, current_step)
+                        else:  # [B,H,W,C] - Image data
+                            self._log_image_samples(samples, current_step)
+
+            if getattr(self, 'wandb', None) is not None and self.wandb:
+                # metrics is a dict of metrics
+                if metrics and type(metrics) == dict:
+                    # Flatten the metrics
+                    metrics = {k: np.mean(v) for k, v in metrics.items()}
+                    # Log the metrics
+                    for key, value in metrics.items():
+                        if isinstance(value, jnp.ndarray):
+                            value = np.array(value)
+                        self.wandb.log({
+                            f"val/{key}": value,
+                        }, step=current_step)
+
+        except StopIteration:
+            print(f"Validation dataset exhausted for process index {process_index}")
         except Exception as e:
-            print("Error in validation loop:", e)
+            print(f"Error during validation for process index {process_index}: {e}")
             import traceback
             traceback.print_exc()
+
 
     def _log_video_samples(self, samples, current_step):
         """Helper to log video samples to wandb."""
```
flaxdiff/trainer/simple_trainer.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -411,7 +411,9 @@ def train_loop(
         train_ds,
         train_steps_per_epoch,
         current_step,
-        rng_state
+        rng_state,
+        save_every: int = None,
+        val_every=None,
     ):
         global_device_count = jax.device_count()
         process_index = jax.process_index()
@@ -491,8 +493,8 @@
                         "train/loss": loss,
                     }, step=current_step)
                 # Save the model every few steps
-                if i % 10000 == 0 and i > 0:
-                    print(f"Saving model after 10000 step {current_step}")
+                if save_every and i % save_every == 0 and i > 0:
+                    print(f"Saving model after {save_every} step {current_step}")
                     print(f"Devices: {len(jax.devices())}")  # To sync the devices
                     self.save(current_epoch, current_step, train_state, rng_state)
                     print(f"Saving done by process index {process_index}")
@@ -518,7 +520,7 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps
             self.validation_loop(
                 train_state,
                 val_step,
-                data.get('test', data.get('val', None)),
+                data.get('val', data.get('test', None)),
                 val_steps_per_epoch,
                 self.latest_step,
             )
@@ -569,7 +571,7 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps
             self.validation_loop(
                 train_state,
                 val_step,
-                data.get('test', None),
+                data.get('val', data.get('test', None)),
                 val_steps_per_epoch,
                 current_step,
             )
```
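The hard-coded 10000-step checkpoint cadence becomes the `save_every` parameter: passing `save_every=10000` reproduces the old behaviour, and leaving it unset disables periodic saves, since a falsy value short-circuits the check. The gating condition in isolation (a sketch, not repository code):

```python
def should_save(step: int, save_every: int | None) -> bool:
    # Mirrors the diff: never save at step 0, and skip entirely when
    # save_every is None or 0.
    return bool(save_every) and step % save_every == 0 and step > 0

assert should_save(10000, 10000)   # old hard-coded cadence, now opt-in
assert not should_save(0, 10000)   # no save at step 0
assert not should_save(500, None)  # disabled when unset
```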
