Skip to content

Commit 0d6a092

Browse files
committed
Apply a different random seed on each rank in Bagel so that the noise is sampled differently across ranks
1 parent de490d8 commit 0d6a092

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/lmms_engine/datasets/processor/bagel_processor.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import numpy as np
55
import torch
6+
from loguru import logger
67
from PIL import Image
7-
from transformers import Qwen2Tokenizer
8+
from transformers import Qwen2Tokenizer, set_seed
89

910
from lmms_engine.mapping_func import register_processor
1011
from lmms_engine.models.bagel.data_utils import (
@@ -40,10 +41,22 @@ def __init__(self, config: ProcessorConfig) -> None:
4041
self.vae_cond_dropout_prob = float(extra_kwargs.get("vae_cond_dropout_prob", 0.0))
4142

4243
self.user_image_as_vae_condition = extra_kwargs.get("user_image_as_vae_condition", None)
44+
self.set_random_seed(extra_kwargs.get("random_seed", 4396))
4345

4446
def build(self):
    """Construct the processor and keep it on ``self.processor``."""
    built_processor = self._build_processor()
    self.processor = built_processor
4648

49+
def set_random_seed(self, seed: int):
    """Seed the global RNGs with a value that differs on every rank.

    The effective seed is ``seed * world_size + rank``. When no process
    group is initialized (or this PyTorch build has no distributed
    support) the process is treated as rank 0 of a world of size 1, so
    the effective seed equals ``seed``.

    Args:
        seed: Base seed shared by all ranks.
    """
    dist = torch.distributed
    # The c10d helpers (is_initialized, get_rank, ...) are only exported on
    # builds with distributed support, so check availability first.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        world_size = 1
        rank = 0
    # rank < world_size, so this mapping gives each (seed, rank) pair a
    # distinct effective seed and no two ranks share one.
    seed = seed * world_size + rank
    logger.info(f"Set random seed to {seed} for rank {rank} in world size {world_size}")
    set_seed(seed)
59+
4760
def _build_processor(self):
4861
extra_kwargs = getattr(self.config, "extra_kwargs", None) or {}
4962
processor = Qwen2Tokenizer.from_pretrained(self.config.processor_name)

0 commit comments

Comments
 (0)