@@ -1,11 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import inspect
+import os
 import re
 from contextlib import contextmanager
 from copy import deepcopy
 from functools import partial, wraps
 from types import MethodType
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
 
 import json
 import torch
@@ -1539,6 +1540,32 @@ class Llama3Template(Llama3TemplateMixin, Template):
     Template(['<s>'], ['<|User|>:{{QUERY}}\n<|Bot|>:'], ['<eoa>\n'], ['<eoa>'], INTERNLM_SYSTEM,
              ['<s><|System|>:{{SYSTEM}}\n']))
 
+_T = TypeVar('_T')
+
+_log_set = set()  # log once
+
+
+def get_env_args(args_name: str,
+                 type_func: Callable[[str], _T] = int,
+                 default_value: Optional[_T] = None) -> Optional[_T]:
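+    """Resolve `args_name` from the upper-cased environment variable of the
+    same name if it is set (parsed with `type_func`); otherwise fall back to
+    `default_value`. E.g. `HD_NUM=16` makes `get_env_args('hd_num', int, 24)`
+    return 16. The resolved value is logged once per distinct message."""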
+    args_name_upper = args_name.upper()
+    value = os.getenv(args_name_upper)
+    if value is None:
+        value = default_value
+        log_info = (f'Setting {args_name}: {default_value}. '
+                    f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
+    else:
+        value = type_func(value)
+        log_info = f'Using environment variable `{args_name_upper}`, setting {args_name}: {value}.'
+    if log_info not in _log_set:
+        _log_set.add(log_info)
+        logger.info(log_info)
+    return value
+
 
 class Internlm2Template(ChatmlTemplate):
     system = INTERNLM_SYSTEM
@@ -1595,12 +1618,15 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
         if self.version == 'v2.5':
             hd_num = 24
-            Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
             if len(images) > 1:
                 hd_num = 6
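+            # The `HD_NUM` environment variable overrides the default hd_num (24, or 6 for multi-image).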
+            hd_num = get_env_args('hd_num', int, hd_num)
+            Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
             images = [Image_transform(image, hd_num=hd_num) for image in images]
         elif self.version == 'v2-4khd':
             hd_num = 55
+            hd_num = get_env_args('hd_num', int, hd_num)
             HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', self.tokenizer.model_dir)
             images = [HD_transform(image, hd_num=hd_num) for image in images]
         images = [self.model.vis_processor(image).to(dtype) for image in images]
@@ -1723,7 +1748,10 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         images = example.get('images')
         if images:
             labels = inputs.get('labels')
-            pixel_values_images = [transform_image(image) for image in images]
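+            # `INPUT_SIZE` and `MAX_NUM` environment variables override the 448/12 defaults.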
+            input_size = get_env_args('input_size', int, 448)
+            max_num = get_env_args('max_num', int, 12)
+            pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
             pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model.dtype)
             image_bs = pixel_values.shape[0]
 
@@ -1784,7 +1811,9 @@ def replace_tag(self, media_type, index, example) -> List[Context]:
         if media_type == 'image':
             return image_context
         elif media_type == 'video':
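+            # The `VIDEO_SEGMENTS` environment variable overrides the default number of sampled frames.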
-            load_video = partial(load_video_internvl, num_segments=self.video_segments)
+            video_segments = get_env_args('video_segments', int, self.video_segments)
+            load_video = partial(load_video_internvl, num_segments=video_segments)
             return _replace_video2image(load_video, example, lambda i: [f'Frame{i + 1}: '] + image_context)
 
     def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
@@ -1816,7 +1844,10 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         images = example.get('images')
         if images:
            has_video = bool(example.get('videos'))
-            pixel_values = [transform_image(image, max_num=1 if has_video else 12) for image in images]
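+            # `max_num` defaults to 1 for video inputs; `INPUT_SIZE`/`MAX_NUM` env vars override.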
+            input_size = get_env_args('input_size', int, 448)
+            max_num = get_env_args('max_num', int, 1 if has_video else 12)
+            pixel_values = [transform_image(image, input_size, max_num) for image in images]
             num_patches = [pv.shape[0] for pv in pixel_values]
             pixel_values = torch.cat(pixel_values).to(self.model.dtype)
         else:
@@ -1924,7 +1954,9 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         processor = self.tokenizer.processor
         images = example.get('images') or []
         assert len(images) == 1, 'Florence series models only support input with a single image.'
-        image_tensors = transform_image(images[0])
+        input_size = get_env_args('input_size', int, 448)
+        max_num = get_env_args('max_num', int, 12)
+        image_tensors = transform_image(images[0], input_size, max_num)
         example['_image'] = image_tensors
 
         # process bbox
@@ -2789,6 +2821,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             use_image_id = False
             max_slice_nums = 1  # or 2
 
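+        # The `MAX_SLICE_NUMS` environment variable overrides the value chosen above.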
+        max_slice_nums = get_env_args('max_slice_nums', int, max_slice_nums)
         input_ids = inputs['input_ids']
         labels = inputs['labels']
         idx_list = _findall(input_ids, -100)