@@ -13,8 +13,8 @@
 
 
 class FlorenceTemplate(Template):
-    # loss_scale = 'last_round'
-    # skip_prompt = False
+    # If it's an encoder-decoder architecture, the default settings are
+    # loss_scale: 'last_round' and skip_prompt: False.
     is_encoder_decoder = True
 
     @staticmethod
@@ -51,28 +51,32 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         labels = encoded['labels']
         if labels is not None:
             labels = [0] + labels
-        pixel_values = processor.image_processor(
-            images, return_tensors='pt')['pixel_values'].to(self.config.torch_dtype)
-        encoded = {
-            'input_ids': input_ids,
-            'labels': labels,
-            'pixel_values': pixel_values,
-        }
+        if images:
+            pixel_values = processor.image_processor(
+                images, return_tensors='pt')['pixel_values'].to(self.config.torch_dtype)
+            encoded['pixel_values'] = pixel_values
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
         return encoded
 
     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
         inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
-        image_features = model._encode_image(inputs['pixel_values'])
-        inputs_embeds, _ = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
+        pixel_values = inputs.get('pixel_values')
+        if pixel_values is not None:
+            image_features = model._encode_image(pixel_values)
+            inputs_embeds, inputs['attention_mask'] = model._merge_input_ids_with_image_features(
+                image_features, inputs_embeds)
         return {'inputs_embeds': inputs_embeds}
 
     def decode(self, generate_ids: List[int], **kwargs) -> Any:
         response = super().decode(generate_ids, **kwargs)
         template_inputs = kwargs.get('template_inputs')
         images = template_inputs.images
+        image_size = None
+        if images:
+            image_size = (images[0].width, images[0].height)
         return json.dumps(
-            self.processor.post_process_generation(
-                response, task=template_inputs.query, image_size=(images[0].width, images[0].height)))
+            self.processor.post_process_generation(response, task=template_inputs.query, image_size=image_size))
 
 
 register_template(
|
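Taken together, the hunks above make image input optional for this template: _encode only computes pixel_values when images is non-empty, _post_encode only merges image features when pixel_values is present, and decode passes image_size=None for text-only requests. Below is a minimal, self-contained sketch of that "optional vision branch" pattern; fake_image_processor, encode_sample and post_encode_sample are hypothetical stand-ins for illustration, not part of ms-swift or the Florence processor API.

from typing import Any, Dict, List, Optional


def fake_image_processor(images: List[Any]) -> List[List[float]]:
    # Hypothetical stand-in for processor.image_processor: pretend every
    # image is turned into a small feature vector.
    return [[0.0, 0.0, 0.0] for _ in images]


def encode_sample(input_ids: List[int],
                  labels: Optional[List[int]],
                  images: List[Any]) -> Dict[str, Any]:
    encoded: Dict[str, Any] = {}
    if images:
        # The vision branch only runs when at least one image was supplied.
        encoded['pixel_values'] = fake_image_processor(images)
    encoded['input_ids'] = input_ids
    encoded['labels'] = labels
    return encoded


def post_encode_sample(encoded: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the _post_encode guard: merge image features only if present.
    merged: Dict[str, Any] = {'input_ids': encoded['input_ids']}
    pixel_values = encoded.get('pixel_values')
    if pixel_values is not None:
        merged['image_features'] = pixel_values
    return merged


# Text-only sample: no 'pixel_values' key, so the merge step is skipped.
print(post_encode_sample(encode_sample([101, 2023, 102], None, [])))
# Multimodal sample: 'pixel_values' is present and gets merged.
print(post_encode_sample(encode_sample([101, 2023, 102], [0, 2023, 102], ['img'])))

Leaving 'pixel_values' out of the dict entirely (rather than storing None or an empty tensor) is what lets the later stage use a plain inputs.get('pixel_values') check, which is exactly the guard the _post_encode hunk adds.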