@@ -61,14 +61,20 @@ def __init__(self, controller_addr, worker_addr,
         else:
             self.model_name = model_name
 
-        self.device = device
         logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
         from transformers import AutoTokenizer, CLIPImageProcessor
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        self.model = InternVLChatModel.from_pretrained(
-            model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16).eval()
-        if not load_8bit:
+        if device == 'auto':
+            import os
+            os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+            # Setting this makes distributed deployment work properly; the reason is unclear.
+            self.model = InternVLChatModel.from_pretrained(
+                model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16, device_map='auto').eval()
+        else:
+            self.model = InternVLChatModel.from_pretrained(
+                model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16).eval()
+        if not load_8bit and device != 'auto':
            self.model = self.model.cuda()
         self.image_size = self.model.config.force_image_size
         self.image_processor = CLIPImageProcessor(
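Background on the hunk above (not part of the diff itself): device_map='auto' delegates placement to Hugging Face Accelerate, which shards the model's layers across all visible GPUs according to free memory, so the explicit .cuda() call is skipped in that branch. CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous; it is normally a debugging setting and trades some throughput for clearer error reporting. A minimal sketch of the same loading pattern follows, with an illustrative checkpoint name (an assumption, not taken from this commit):

import torch
from transformers import AutoModel

# device_map='auto' lets Accelerate place each layer on a GPU (or the CPU)
# based on available memory; the returned model is already on-device.
model = AutoModel.from_pretrained(
    'OpenGVLab/InternVL-Chat-V1-5',  # illustrative checkpoint, an assumption
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True,
).eval()

# Accelerate records the placement, e.g. {'vision_model': 0, 'language_model': 1, ...}
print(model.hf_device_map)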
@@ -184,7 +190,7 @@ def generate_stream(self, params):
         stop_str = params.get('stop', None)
         do_sample = True if temperature > 0.001 else False
         logger.info(f'num_image_tokens: {num_image_tokens}')
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, num_image_tokens, return_tensors='pt').unsqueeze(0).to(self.device)
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, num_image_tokens, return_tensors='pt').unsqueeze(0).cuda()
         input_ids[input_ids == IMAGE_TOKEN_INDEX] = model.img_context_token_id
 
         keywords = [stop_str]
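Since self.device no longer exists, inputs are now moved with .cuda(), i.e. onto the default GPU (cuda:0). Under device_map='auto' the first shard (including the embedding layer) typically lands there, and Accelerate's hooks move activations between shards automatically during the forward pass. A hedged sketch of a more explicit alternative, relying on the hf_device_map attribute that Accelerate populates (this helper logic is an assumption, not part of the commit):

# Send the inputs to whichever device holds the first model shard.
first_device = next(iter(model.hf_device_map.values()))  # e.g. 0 or 'cuda:0'
if isinstance(first_device, int):
    first_device = f'cuda:{first_device}'
input_ids = input_ids.to(first_device)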