Skip to content

Commit 1f32226

Browse files
authored
fix device map (#132)
* fix device map
1 parent 551a85b commit 1f32226

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

internvl_chat/internvl/model/internlm2/modeling_internlm2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,13 +1082,16 @@ def forward(
10821082
output = (logits,) + outputs[1:]
10831083
return (loss,) + output if loss is not None else output
10841084

1085-
return CausalLMOutputWithPast(
1085+
device = input_ids.device if input_ids is not None else inputs_embeds.device
1086+
output = CausalLMOutputWithPast(
10861087
loss=loss,
10871088
logits=logits,
10881089
past_key_values=outputs.past_key_values,
10891090
hidden_states=outputs.hidden_states,
10901091
attentions=outputs.attentions,
10911092
)
1093+
output['logits'] = output['logits'].to(device)
1094+
return output
10921095

10931096
def prepare_inputs_for_generation(
10941097
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs

internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def extract_feature(self, pixel_values):
226226
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
227227
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
228228
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
229-
vit_embeds = self.mlp1(vit_embeds)
229+
vit_embeds = self.mlp1(vit_embeds).to(pixel_values.device)
230230
return vit_embeds
231231

232232
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,

0 commit comments

Comments
 (0)