@@ -45,8 +45,9 @@ def load_prompts(dataset_path):
4545
4646
4747class Model :
48- def __init__ (self , model_path , device , config , prompts , fixed_latent = None ):
48+ def __init__ (self , model_path , device , config , prompts , fixed_latent = None , rank = 0 ):
4949 self .device = device
50+ self .rank = rank
5051 self .height = config ['height' ]
5152 self .width = config ['width' ]
5253 self .num_frames = config ['num_frames' ]
@@ -104,7 +105,7 @@ def flush_queries(self):
104105
105106
106107class DebugModel :
107- def __init__ (self , model_path , device , config , prompts , fixed_latent = None ):
108+ def __init__ (self , model_path , device , config , prompts , fixed_latent = None , rank = 0 ):
108109 self .prompts = prompts
109110
110111 def issue_queries (self , query_samples ):
@@ -251,8 +252,8 @@ def run_mlperf(args, config):
251252 logging .info ("No fixed latent provided - using random initial latents" )
252253
253254 # Loading model
254- model = Model (args .model_path , device , config , dataset , fixed_latent )
255- #model = DebugModel(args.model_path, device, config, dataset, fixed_latent)
255+ model = Model (args .model_path , device , config , dataset , fixed_latent , rank )
256+ #model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank )
256257 logging .info ("Model loaded successfully!" )
257258
258259 # Prepare loadgen for run
0 commit comments