File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
src/lmms_engine/datasets/processor Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change 33
44import numpy as np
55import torch
6+ from loguru import logger
67from PIL import Image
7- from transformers import Qwen2Tokenizer
8+ from transformers import Qwen2Tokenizer , set_seed
89
910from lmms_engine .mapping_func import register_processor
1011from lmms_engine .models .bagel .data_utils import (
@@ -40,10 +41,22 @@ def __init__(self, config: ProcessorConfig) -> None:
4041 self .vae_cond_dropout_prob = float (extra_kwargs .get ("vae_cond_dropout_prob" , 0.0 ))
4142
4243 self .user_image_as_vae_condition = extra_kwargs .get ("user_image_as_vae_condition" , None )
44+ self .set_random_seed (extra_kwargs .get ("random_seed" , 4396 ))
4345
4446 def build (self ):
4547 self .processor = self ._build_processor ()
4648
49+ def set_random_seed (self , seed : int ):
50+ if torch .distributed .is_initialized ():
51+ world_size = torch .distributed .get_world_size ()
52+ rank = torch .distributed .get_rank ()
53+ else :
54+ world_size = 1
55+ rank = 0
56+ seed = seed * world_size + rank
57+ logger .info (f"Set random seed to { seed } for rank { rank } in world size { world_size } " )
58+ set_seed (seed )
59+
4760 def _build_processor (self ):
4861 extra_kwargs = getattr (self .config , "extra_kwargs" , None ) or {}
4962 processor = Qwen2Tokenizer .from_pretrained (self .config .processor_name )
You can’t perform that action at this time.
0 commit comments