Skip to content

Commit ea04127

Browse files
author
hanjian.thu123
committed
[update] add enable_model_cache, default set to 0
1 parent b65c48c commit ea04127

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ Each "[h_div_w_template1]_[num_examples].jsonl" file contains lines of dumped js
110110
{
111111
"image_path": "path/to/image, required",
112112
"h_div_w": "float value of h_div_w for the image, required",
113-
"long_caption": long_caption of the image, required",
113+
"long_caption": "long caption of the image, required",
114114
"long_caption_type": "InternVL 2.0, required",
115-
"short_caption": "short of the image, optional",
115+
"text": "short caption of the image, optional",
116116
"short_caption_type": "user prompt, optional"
117117
}
118118
```

tools/run_infinity.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -275,32 +275,34 @@ def load_transformer(vae, args):
275275
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
276276
model_path = args.model_path
277277
if args.checkpoint_type == 'torch':
278-
assert ('ar-' in model_path) or ('slim-' in model_path)
279-
# copy large model to local, save slim to local, and copy slim to nas, and load local slim model
278+
# copy large model to local; save slim to local; and copy slim to nas; load local slim model
280279
if osp.exists(args.cache_dir):
281280
local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
282281
else:
283282
local_model_path = model_path
284-
slim_model_path = model_path.replace('ar-', 'slim-')
285-
local_slim_model_path = local_model_path.replace('ar-', 'slim-')
286-
os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
287-
print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
288-
print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
289-
if not osp.exists(local_slim_model_path):
290-
if osp.exists(slim_model_path):
291-
print(f'copy {slim_model_path} to {local_slim_model_path}')
292-
shutil.copyfile(slim_model_path, local_slim_model_path)
293-
else:
294-
if not osp.exists(local_model_path):
295-
print(f'copy {model_path} to {local_model_path}')
296-
shutil.copyfile(model_path, local_model_path)
297-
save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
298-
print(f'copy {local_slim_model_path} to {slim_model_path}')
299-
if not osp.exists(slim_model_path):
300-
shutil.copyfile(local_slim_model_path, slim_model_path)
301-
os.remove(local_model_path)
302-
os.remove(model_path)
303-
slim_model_path = local_slim_model_path
283+
if args.enable_model_cache:
284+
slim_model_path = model_path.replace('ar-', 'slim-')
285+
local_slim_model_path = local_model_path.replace('ar-', 'slim-')
286+
os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
287+
print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
288+
print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
289+
if not osp.exists(local_slim_model_path):
290+
if osp.exists(slim_model_path):
291+
print(f'copy {slim_model_path} to {local_slim_model_path}')
292+
shutil.copyfile(slim_model_path, local_slim_model_path)
293+
else:
294+
if not osp.exists(local_model_path):
295+
print(f'copy {model_path} to {local_model_path}')
296+
shutil.copyfile(model_path, local_model_path)
297+
save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
298+
print(f'copy {local_slim_model_path} to {slim_model_path}')
299+
if not osp.exists(slim_model_path):
300+
shutil.copyfile(local_slim_model_path, slim_model_path)
301+
os.remove(local_model_path)
302+
os.remove(model_path)
303+
slim_model_path = local_slim_model_path
304+
else:
305+
slim_model_path = model_path
304306
print(f'load checkpoint from {slim_model_path}')
305307

306308
if args.model_type == 'infinity_2b':
@@ -358,9 +360,11 @@ def add_common_arguments(parser):
358360
parser.add_argument('--use_flex_attn', type=int, default=0, choices=[0,1])
359361
parser.add_argument('--enable_positive_prompt', type=int, default=0, choices=[0,1])
360362
parser.add_argument('--cache_dir', type=str, default='/dev/shm')
363+
parser.add_argument('--enable_model_cache', type=int, default=0, choices=[0,1])
361364
parser.add_argument('--checkpoint_type', type=str, default='torch')
362365
parser.add_argument('--seed', type=int, default=0)
363366
parser.add_argument('--bf16', type=int, default=1, choices=[0,1])
367+
364368

365369

366370
if __name__ == '__main__':

0 commit comments

Comments
 (0)