
Commit 481e00b

add pp_dim, distributed, num_gpus, num_nodes as cmd line args
1 parent d1ab6e0 commit 481e00b

File tree

2 files changed: +45, -36 lines

torchchat/cli/builder.py

Lines changed: 27 additions & 16 deletions
@@ -16,20 +16,12 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +32,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -55,7 +55,10 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_distributed: bool = False
+    distributed: bool = False
+    num_gpus: int = 1
+    num_nodes: int = 1
+    pp_dim: int = 1
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -156,7 +159,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dtype = torch.float16
         else:
             dtype = name_to_dtype(args.dtype, args.device)
-
+        # distributed args
+        distributed = getattr(args, "distributed", False)
+        num_gpus = getattr(args, "num_gpus", 1)
+        num_nodes = getattr(args, "num_nodes", 1)
+        pp_dim = getattr(args, "pp_dim", 1)
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -170,7 +177,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             device=args.device,
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
-            use_distributed=args.distributed,
+            distributed=distributed,
+            num_gpus=num_gpus,
+            num_nodes=num_nodes,
+            pp_dim=pp_dim,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -400,10 +410,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     # does not host any actual values, need to reinitialize them in the actual
     # device. Only do those buffer initialization, without initializing the entire
     # model.
-    decoder_config = model.config.transformer_args['decoder']
-    head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-    max_seq_len = decoder_config['max_seq_len']
-    rope_base = decoder_config['rope_base']
+    decoder_config = model.config.transformer_args["decoder"]
+    head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+    max_seq_len = decoder_config["max_seq_len"]
+    rope_base = decoder_config["rope_base"]
     for submodule in model.modules():
         if isinstance(submodule, Llama3ScaledRoPE):
             submodule.__init__(head_dim, max_seq_len, rope_base)
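
The hunk above is only a quoting change, but the surrounding comment describes a real constraint: the model is first built on the meta device, where buffers have shape but no storage, so each Llama3ScaledRoPE is re-run through __init__ on the real device. A minimal sketch of that idiom with a toy module (not torchchat's actual class):

import torch
import torch.nn as nn

class ToyRoPE(nn.Module):
    # Toy stand-in for Llama3ScaledRoPE: all state lives in a buffer
    # that __init__ can rebuild from scratch.
    def __init__(self, dim: int):
        super().__init__()
        self.register_buffer("freqs", torch.arange(dim, dtype=torch.float32))

with torch.device("meta"):
    module = ToyRoPE(8)  # buffer is a meta tensor: shape only, no storage

module = module.to_empty(device="cpu")  # allocate real storage, values undefined
module.__init__(8)                      # re-run __init__ to recompute the buffer
print(module.freqs)                     # now holds real values again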
@@ -491,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
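
Note that from_args reads the new fields with getattr(..., default), so argument namespaces built by subcommands that never register these flags still parse. The registration itself is not part of this diff; a minimal sketch of what it might look like (flag names and help text are assumptions inferred from the getattr defaults, not code from this commit):

import argparse

def add_distributed_args(parser: argparse.ArgumentParser) -> None:
    # Hypothetical flag registration matching the getattr() defaults in
    # BuilderArgs.from_args; argparse maps --num-gpus to args.num_gpus, etc.
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Shard inference across ranks (guarded off for now).",
    )
    parser.add_argument("--num-gpus", type=int, default=1, help="GPUs per node.")
    parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes.")
    parser.add_argument("--pp-dim", type=int, default=1, help="Pipeline-parallel degree.")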

torchchat/generate.py

Lines changed: 18 additions & 20 deletions
@@ -24,15 +24,6 @@
 
 from PIL import Image
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
-
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -43,6 +34,15 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -239,23 +239,17 @@ def __init__(
         self.is_torchtune_model = generator_args.is_torchtune_model
         self.dtype = builder_args.precision
 
-        # global print
-        # from tp import maybe_init_dist
-        # rank = maybe_init_dist()
-        # use_distributed = False
         self.rank: Optional[int] = None
-        # if use_distributed:
-        #     if rank != 0:
-        #         # only print on rank 0
-        #         print = lambda *args, **kwargs: None
 
         print(
             f"Using device={self.builder_args.device} {get_device_info(self.builder_args.device)}"
         )
         set_precision(self.builder_args.precision)
-        if builder_args.use_distributed:
+        if builder_args.distributed:
+            print(f"Using distributed={builder_args.distributed}")
             device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
             torch.cuda.set_device(device)
+            assert False, "Distributed is not supported yet"
         self.is_speculative = self.speculative_builder_args.checkpoint_path is not None
 
         if generator_args.chat_mode and not self.builder_args.is_chat_model:
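
The distributed branch above selects a GPU from LOCAL_RANK and then immediately asserts, so the path is announced but guarded off until distributed inference actually lands. LOCAL_RANK is injected by the process launcher, not by torchchat; a minimal sketch of the idiom (the launch command is illustrative only, not from this commit):

import os

import torch

# A launcher such as torchrun sets LOCAL_RANK per spawned process, e.g.:
#   torchrun --nproc_per_node=2 torchchat.py generate ... --distributed
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
    # Pin this process to its own GPU before any CUDA allocations occur.
    torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))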
@@ -938,7 +932,8 @@ def chat(
                 TransformerCrossAttentionLayer,
                 TransformerSelfAttentionLayer,
             )
-            decoder = self.model.model.decoder
+
+            decoder = self.model.model.decoder
             for m in reversed(list(decoder.modules())):
                 if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
                     m, TransformerCrossAttentionLayer
@@ -984,7 +979,10 @@ def chat(
            # `is_torchtune_model` is a misnomer since it doesn't capture all
            # torchtune models (i.e. Flamingo)
            # See Issue: https://github.com/pytorch/torchchat/issues/1273
-           elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+           elif (
+               not generator_args.is_torchtune_model
+               and self.model.config.model_type != ModelType.Flamingo
+           ):
                max_seq_length = min(
                    encoded.size(0) + generator_args.max_new_tokens,
                    (
