
Commit b93d0b6

Refactor tensor creation dtype / device control.
This commit makes two changes to model creation:

1. Decouple promote_trainable_params_to_fp32 from the model __init__, so that inference-only mode can skip the fp32 cast and save memory (#4).

2. Use a context manager to control the default tensor type. Previously, the default tensor type was reset to torch.FloatTensor after creating the vision model, which is technically incorrect: it should restore whatever the previous default was. We implement our own context manager because the official ones are incomplete as of PyTorch 2.0.1: no context manager is provided for the default dtype, and set_default_device has no effect on the torch.Tensor constructor calls used in fairscale.
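For reference, the resulting creation pattern in the training scripts (a condensed sketch of the diffs below; the dtype shown is just one of the supported choices) is roughly:

# Build the model under low-precision, on-device defaults to save memory...
with default_tensor_type(dtype=torch.bfloat16, device="cuda"):
    model = MetaModel(args.llama_type, args.llama_config,
                      args.tokenizer_path, with_visual=not args.no_visual)
# ...then, only when training, promote the trainable parameters back to fp32.
promote_trainable_params_to_fp32(model)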
1 parent 0043e8f commit b93d0b6

File tree

8 files changed (+103, -38 lines)


accessory/demos/multi_turn.py

Lines changed: 7 additions & 7 deletions
@@ -15,6 +15,7 @@
 import gradio as gr
 
 from util.misc import setup_for_distributed, load_pretrained
+from util.tensor_type import default_tensor_type
 from model.meta import MetaModel
 from data.conversation.lib import conv_templates, SeparatorStyle
 
@@ -50,17 +51,16 @@ def model_worker(
     # set the print behavior.
     setup_for_distributed(rank == 0)
 
-    model = MetaModel(
-        args.llama_type, args.llama_config, args.tokenizer_path,
-        with_visual=False, max_seq_len=args.model_max_seq_len,
-    )
     target_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
     }[args.dtype]
-    for n, p in model.named_parameters():
-        p.data = p.data.to(target_dtype)
-    model.cuda().eval()
+    with default_tensor_type(dtype=target_dtype, device="cuda"):
+        model = MetaModel(
+            args.llama_type, args.llama_config, args.tokenizer_path,
+            with_visual=False, max_seq_len=args.model_max_seq_len,
+        )
+    model.eval()
     print(f"Loading pretrained weights from {args.pretrained_path}")
     load_pretrained(args.pretrained_path, args.pretrained_type, model)
     print(f"Model = {str(model)}")

accessory/main_finetune.py

Lines changed: 10 additions & 7 deletions
@@ -32,6 +32,7 @@
 
 import util.misc as misc
 from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
 from model.meta import MetaModel
 from engine_finetune import train_one_epoch
 from torch.utils.data import Dataset
@@ -150,8 +151,15 @@ def main(args):
     dp_group = fs_init.get_data_parallel_group()
 
     # define the model
-    model = MetaModel(args.llama_type, args.llama_config,
-                      args.tokenizer_path, with_visual=not args.no_visual)
+    mixed_precision_dtype = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "tf32": torch.float32,
+    }[args.precision]
+    with default_tensor_type(dtype=mixed_precision_dtype, device="cuda"):
+        model = MetaModel(args.llama_type, args.llama_config,
+                          args.tokenizer_path, with_visual=not args.no_visual)
+    promote_trainable_params_to_fp32(model)
     print(f"load pretrained from {args.pretrained_path}")
     misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
     print("Unwrapped Model = %s" % str(model))
@@ -160,11 +168,6 @@ def main(args):
     if args.resume:
         misc.resume_stage1(args, model_without_FSDP=model)
 
-    mixed_precision_dtype = {
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "tf32": torch.float32,
-    }[args.precision]
     TransformerBlock = type(model.llma.layers[0])
     # ignored_named_parameters = {name: param for name, param in model.named_parameters() if not param.requires_grad}
     # print(ignored_named_parameters.keys())

accessory/main_pretrain.py

Lines changed: 11 additions & 7 deletions
@@ -32,6 +32,7 @@
 
 import util.misc as misc
 from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
 from model.meta import MetaModel
 from engine_pretrain import train_one_epoch, val_one_epoch
 from torch.utils.data import Dataset
@@ -147,8 +148,15 @@ def main(args):
     dp_group = fs_init.get_data_parallel_group()
 
     # define the model
-    model = MetaModel(args.llama_type, args.llama_config,
-                      args.tokenizer_path, with_visual=False)
+    mixed_precision_dtype = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "tf32": torch.float32,
+    }[args.precision]
+    with default_tensor_type(dtype=mixed_precision_dtype, device="cuda"):
+        model = MetaModel(args.llama_type, args.llama_config,
+                          args.tokenizer_path, with_visual=False)
+    promote_trainable_params_to_fp32(model)
     if args.pretrained_path:
         print(f"load pretrained from {args.pretrained_path}")
         misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
@@ -158,11 +166,7 @@ def main(args):
     if args.resume:
         misc.resume_stage1(args, model_without_FSDP=model)
 
-    mixed_precision_dtype = {
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "tf32": torch.float32,
-    }[args.precision]
+
     TransformerBlock = type(model.llma.layers[0])
 
     model = FSDP(

accessory/model/LLM/llama.py

Lines changed: 3 additions & 5 deletions
@@ -20,6 +20,7 @@
 from apex.normalization import FusedRMSNorm as RMSNorm
 import open_clip
 
+from util.tensor_type import default_tensor_type
 import configs.global_configs
 if configs.global_configs.USE_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
@@ -308,9 +309,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.cache_image_words = 0 # for inference
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -334,9 +334,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/model/LLM/llama_adapter.py

Lines changed: 3 additions & 5 deletions
@@ -24,6 +24,7 @@
 import configs.global_configs
 if configs.global_configs.USE_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
+from util.tensor_type import default_tensor_type
 
 default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))
 
@@ -349,9 +350,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.image_words = 0
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -401,9 +401,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/model/LLM/llama_peft.py

Lines changed: 3 additions & 5 deletions
@@ -17,6 +17,7 @@
     ColumnParallelLinear
 )
 from ..peft import LoraColumnParallelLinear, LoraRowParallelLinear
+from util.tensor_type import default_tensor_type
 
 from apex.normalization import FusedRMSNorm as RMSNorm
 import open_clip
@@ -323,9 +324,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.cache_image_words = 0 # for inference
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -351,9 +351,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/model/LLM/llama_qformerv2.py

Lines changed: 0 additions & 2 deletions
@@ -337,9 +337,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/util/tensor_type.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+from types import TracebackType
+from typing import Any, Optional
+import torch
+import torch.nn as nn
+
+
+class default_tensor_type:
+    _tensor_type_stack = [(torch.float, "cpu")]
+
+    def __init__(
+        self,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        # Only limited combinations are supported.
+        assert device is None or device in ["cpu", "cuda"]
+        assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half]
+        self.dtype, self.device = dtype, device
+
+    def __enter__(self) -> None:
+        dtype, device = self.dtype, self.device
+        if dtype is None:
+            dtype = default_tensor_type._tensor_type_stack[-1][0]
+        if device is None:
+            device = default_tensor_type._tensor_type_stack[-1][1]
+        default_tensor_type._tensor_type_stack.append((dtype, device))
+
+        # We use all 3 calls since the new apis (set_default_device, set_default_dtype)
+        # seems to be ineffective sometimes (e.g., set_default_device is ineffective to
+        # torch.Tensor calls).
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        default_tensor_type._tensor_type_stack.pop()
+        dtype, device = default_tensor_type._tensor_type_stack[-1]
+
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    @staticmethod
+    def get_tensor_type(dtype: torch.dtype, device: str) -> Any:
+        return {
+            (torch.float, "cpu"): torch.FloatTensor,
+            (torch.bfloat16, "cpu"): torch.BFloat16Tensor,
+            (torch.half, "cpu"): torch.HalfTensor,
+            (torch.float, "cuda"): torch.cuda.FloatTensor,
+            (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor,
+            (torch.half, "cuda"): torch.cuda.HalfTensor,
+        }[(dtype, device)]
+
+
+def promote_trainable_params_to_fp32(model: nn.Module) -> None:
+    for param in model.parameters():
+        if param.requires_grad:
+            if param.is_floating_point() and torch.finfo(param.dtype).bits < 32:
+                param.data = param.data.float()
+            if param.is_complex() and torch.finfo(param.dtype).bits < 32:
+                param.data = param.data.to(torch.complex64)
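A minimal usage sketch of the nesting behavior (illustrative values; assumes a CUDA device is available): on exit, the manager restores whatever default was active before it, rather than unconditionally resetting to torch.FloatTensor.

import torch
from util.tensor_type import default_tensor_type

with default_tensor_type(dtype=torch.bfloat16, device="cuda"):
    a = torch.empty(2)           # bf16 on cuda
    with default_tensor_type(dtype=torch.half):
        b = torch.empty(2)       # fp16; device inherited from the enclosing scope (cuda)
    c = torch.empty(2)           # bf16 on cuda again: the previous default is restored
d = torch.empty(2)               # back to the global default (fp32 on cpu)

The class-level stack is what makes this restoration possible: each __exit__ pops its own entry and reapplies the one beneath it.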
