fix fastv bugs #400
```diff
@@ -44,46 +44,13 @@ def build_model(self):
         self.llava_config.use_cache = True
         self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

         self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
             load_8bit=False,
             load_4bit=False,
             device='cpu',
             torch_dtype=self.torch_dtype,
             config=self.llava_config,
-        )
-
-        # llava forward not support "cache_position"
-        ori_forward = self.vlm_model.forward
-
-        def safe_forward(*args, **kwargs):
-            kwargs.pop('cache_position', None)
-            return ori_forward(*args, **kwargs)
-        self.vlm_model.forward = safe_forward
-
-        # llava generate use "inputs" instead of "input_ids"
-        ori_generate = self.vlm_model.generate
-
-        def safe_generate(*args, **kwargs):
-            if 'input_ids' in kwargs:
-                kwargs['inputs'] = kwargs.pop('input_ids')
-            return ori_generate(*args, **kwargs)
-        self.vlm_model.generate = safe_generate
-
-        # "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
-        ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation
-
-        def safe_prepare_inputs_for_generation(
-                self, input_ids, past_key_values=None,
-                inputs_embeds=None, attention_mask=None, **kwargs):
-            if attention_mask is not None:
-                kwargs['attention_mask'] = attention_mask
-            return ori_prepare_inputs_for_generation(
-                input_ids, past_key_values, inputs_embeds, **kwargs)
-        self.vlm_model.prepare_inputs_for_generation = types.MethodType(
-            safe_prepare_inputs_for_generation, self.vlm_model
+            device_map='cpu',
+            attn_implementation='sdpa'
         )

         self.eval_name = 'LlavaEval'
```

Comment on lines +47 to 54

The removal of monkey-patching for

```python
device_map='cpu',
attn_implementation='sdpa'
```
The `update_output_attentions_hook` is registered as a pre-hook, but it manually executes the module's forward pass via `module.__class__.forward`. This causes the forward pass for this layer to be executed twice: once inside the hook, and a second time by the PyTorch framework after the hook returns. This is a major performance bottleneck.

The correct way to get attention scores is to use a `forward_hook`, which runs after the `forward` method and receives its output. The deleted `store_attention_hook` was the right pattern.

If the goal was to get attention scores without affecting the output signature of the layer (which might break subsequent layers), the `forward_hook` should be modified to strip the attention scores from the output tuple before returning.

I strongly recommend reverting to the pre-hook/forward-hook pattern to fix this critical performance issue.
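A minimal sketch of the pre-hook/forward-hook pattern described above, not the repository's actual hook code. It assumes an HF Llama-style decoder layer that returns `(hidden_states, attn_weights, ...)` when `output_attentions=True`, and PyTorch >= 2.0 for `with_kwargs=True` hooks; the names `attention_store`, `force_output_attentions_pre_hook`, and `register_attention_hooks` are illustrative, not taken from this PR.

```python
import torch.nn as nn

attention_store = {}  # illustrative: maps layer id -> captured attention weights


def force_output_attentions_pre_hook(module, args, kwargs):
    # Pre-hook: only rewrites kwargs and never calls forward itself,
    # so the layer still executes exactly once.
    kwargs['output_attentions'] = True
    return args, kwargs


def store_attention_hook(module, inputs, output):
    # Forward hook: runs after forward() and receives its output.
    # Assuming the layer returns (hidden_states, attn_weights, ...) when
    # output_attentions=True, capture the weights and strip them from the
    # tuple so downstream code sees the signature it originally asked for.
    attention_store[id(module)] = output[1]
    return (output[0],) + output[2:]


def register_attention_hooks(decoder_layer: nn.Module):
    # Requires torch>=2.0 for with_kwargs=True on pre-hooks.
    decoder_layer.register_forward_pre_hook(
        force_output_attentions_pre_hook, with_kwargs=True)
    decoder_layer.register_forward_hook(store_attention_hook)
```

Compared with calling `module.__class__.forward` inside a pre-hook, this keeps a single forward execution per layer and leaves the layer's output contract intact for subsequent layers.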