@@ -147,6 +147,7 @@ def __init__(self, device="cpu", max_length=77,
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
         self.return_attention_masks = return_attention_masks
+        self.execution_device = None
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -163,6 +164,7 @@ def freeze(self):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        self.execution_device = options.get("execution_device", self.execution_device)
         if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +177,7 @@ def reset_clip_options(self):
         self.layer = self.options_default[0]
         self.layer_idx = self.options_default[1]
         self.return_projected_pooled = self.options_default[2]
+        self.execution_device = None
 
     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
@@ -258,7 +261,11 @@ def process_tokens(self, tokens, device):
         return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
 
     def forward(self, tokens):
-        device = self.transformer.get_input_embeddings().weight.device
+        if self.execution_device is None:
+            device = self.transformer.get_input_embeddings().weight.device
+        else:
+            device = self.execution_device
+
         embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
 
         attention_mask_model = None
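
For context, here is a minimal, self-contained sketch of the behavior this diff introduces: `set_clip_options` can supply an `execution_device` that takes precedence over the device inferred from the embedding weights in `forward`, and `reset_clip_options` clears the override. The `TinyEncoder` class below is a hypothetical stand-in used only to illustrate the device-selection logic, not the actual class from this file.

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in mirroring the device-selection logic added above."""
    def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(16, 8)
        self.execution_device = None  # optional override; None means "infer from weights"

    def set_clip_options(self, options):
        # Mirrors the diff: only overwrite when the key is provided.
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        self.execution_device = None

    def forward(self, tokens):
        if self.execution_device is None:
            device = self.embeddings.weight.device  # fall back to the weights' device
        else:
            device = self.execution_device  # explicit override wins
        return self.embeddings(tokens.to(device))

enc = TinyEncoder()
enc.set_clip_options({"execution_device": torch.device("cpu")})
out = enc(torch.tensor([1, 2, 3]))
enc.reset_clip_options()  # back to inferring the device from the weights
```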