diff --git a/README.md b/README.md index f6441d2e7..4b910e575 100644 --- a/README.md +++ b/README.md @@ -575,7 +575,7 @@ We really value our community and the contributions made by our wonderful users. To connect with us and other community members, we invite you to join our Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can: * Head to the `#torchchat-general` channel for general questions, discussion, and community support. -* Join the `#torchchat-contribution` channel if you're interested in contributing directly to project development. +* Join the `#torchchat-contributors` channel if you're interested in contributing directly to project development. Looking forward to discussing with you about torchchat future! diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 6344509d8..635789de6 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -9,26 +9,41 @@ set -eou pipefail # Install required python dependencies for developing # Dependencies are defined in .pyproject.toml -PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-python} -if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]]; +if [ -z "${PYTHON_EXECUTABLE:-}" ]; then - PYTHON_EXECUTABLE=python3 + if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]]; + then + PYTHON_EXECUTABLE=python3 + else + PYTHON_EXECUTABLE=python + fi fi - -# Check python version. Expect 3.10.x or 3.11.x -printf "import sys\nif sys.version_info.major != 3 or sys.version_info.minor < 10 :\n\tprint('Please use Python >=3.10');sys.exit(1)\n" | $PYTHON_EXECUTABLE -if [[ $? -ne 0 ]] +echo "Using python executable: $PYTHON_EXECUTABLE" + +PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" +# Check python version. Expect at least 3.10.x +if ! $PYTHON_EXECUTABLE -c " +import sys +if sys.version_info < (3, 10): + sys.exit(1) +"; then + echo "Python version must be at least 3.10.x. Detected version: $PYTHON_SYS_VERSION" exit 1 fi if [[ "$PYTHON_EXECUTABLE" == "python" ]]; then PIP_EXECUTABLE=pip -else +elif [[ "$PYTHON_EXECUTABLE" == "python3" ]]; +then PIP_EXECUTABLE=pip3 +else + PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION} fi +echo "Using pip executable: $PIP_EXECUTABLE" + # # First install requirements in install/requirements.txt. Older torch may be # installed from the dependency of other models. It will be overridden by diff --git a/install/requirements.txt b/install/requirements.txt index d051d29cd..8fb1832ba 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -14,7 +14,6 @@ snakeviz sentencepiece # numpy version range required by GGUF util numpy >= 1.17, < 2.0 -gguf blobfile tomli >= 1.1.0 ; python_version < "3.11" openai diff --git a/tokenizer/base64.h b/tokenizer/base64.h index dfeefef55..12b8703a8 100644 --- a/tokenizer/base64.h +++ b/tokenizer/base64.h @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include diff --git a/torchchat.py b/torchchat.py index 35cdcabae..1eeee0120 100644 --- a/torchchat.py +++ b/torchchat.py @@ -6,7 +6,7 @@ import argparse import logging -import subprocess +import signal import sys # MPS ops missing with Multimodal torchtune @@ -25,7 +25,15 @@ default_device = "cpu" +def signal_handler(sig, frame): + print("\nInterrupted by user. Bye!\n") + sys.exit(0) + + if __name__ == "__main__": + # Set the signal handler for SIGINT + signal.signal(signal.SIGINT, signal_handler) + # Initialize the top-level parser parser = argparse.ArgumentParser( prog="torchchat", diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index a7a22a1e8..f67cb9d0a 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,12 +16,6 @@ import torch._inductor.config import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.elastic.utils.distributed import get_free_port - -from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama - from torchchat.model import Model, ModelArgs, ModelType from torchchat.model_config.model_config import resolve_model_config @@ -464,77 +458,20 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: return model -def _maybe_init_distributed( - builder_args: BuilderArgs, -) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: - """ - Initialize distributed related setups if the user specified - using distributed inference. If not, this is a no-op. - - Args: - builder_args (:class:`BuilderArgs`): - Command args for model building. - Returns: - Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: - - The first element is an optional DeviceMesh object, - which which describes the mesh topology of devices for the DTensor. - - The second element is an optional ParallelDims object, - which represents the parallel dimensions configuration. - """ - if not builder_args.use_distributed: - return None, None - dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line - - world_mesh, parallel_dims = launch_distributed(dist_config) - - assert ( - world_mesh is not None and parallel_dims is not None - ), f"failed to launch distributed using {dist_config}" - - return world_mesh, parallel_dims - - -def _maybe_parallelize_model( - model: nn.Module, - builder_args: BuilderArgs, - world_mesh: DeviceMesh, - parallel_dims: ParallelDims, -) -> nn.Module: - """ - We parallelize the module and load the distributed checkpoint to the model - if the user specifies using distributed inference. If not, this is a no-op. - - Args: - model (:class:`nn.Module`): - Module to be parallelized. - builder_args (:class:`BuilderArgs`): - Command args for model building. - world_mesh (:class:`DeviceMesh`): - Object which describes the mesh topology - of devices for the DTensor. - parallel_dims (:class:`ParallelDims`): - Object which represents the parallel dimensions configuration. - Returns: - A :class:`nn.Module` object which is parallelized and checkpoint loaded - if the user specifies using distributed inference. - """ - if world_mesh is None: - return model - assert parallel_dims is not None - print("Applying model parallel to model ...") - parallelize_llama(model, world_mesh, parallel_dims) - return load_checkpoints_to_model(model, builder_args, world_mesh) - - def _load_model(builder_args: BuilderArgs) -> Model: - # world_mesh, parallel_dims = _maybe_init_distributed(builder_args) if builder_args.gguf_path: model = _load_model_gguf(builder_args) - # elif builder_args.use_distributed: - # model = _init_model_on_meta_device(builder_args) else: model = _load_model_default(builder_args) - # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) + + if builder_args.dso_path or builder_args.aoti_package_path: + # AOTI-compoiled model will load its own weights. + # Release weights here to avoid OOM + import gc + if hasattr(model, "model"): + model.model = None + gc.collect() + torch.cuda.empty_cache() model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() @@ -584,6 +521,12 @@ def _initialize_model( # attributes will NOT be seen on by AOTI-compiled forward # function, e.g. calling model.setup_cache will NOT touch # AOTI compiled and maintained model buffers such as kv_cache. + # Using cpp runner to run AOTI compiled model is recommended. + + def do_nothing(max_batch_size, max_seq_length): + pass + model.setup_caches = do_nothing + model.forward = torch._export.aot_load( str(builder_args.dso_path.absolute()), builder_args.device ) @@ -617,6 +560,11 @@ def _initialize_model( aoti_compiled_model = load_package( str(builder_args.aoti_package_path.absolute()) ) + + def do_nothing(max_batch_size, max_seq_length): + pass + model.setup_caches = do_nothing + model.forward = aoti_compiled_model metadata = aoti_compiled_model.get_metadata() builder_args.device = metadata["AOTI_DEVICE_KEY"] @@ -686,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str: return "TikToken" if tokenizers: return "Tokenizers" - return "SentencePiece" \ No newline at end of file + return "SentencePiece" diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 740f344a8..3a7c85937 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -5,26 +5,24 @@ # LICENSE file in the root directory of this source tree. import argparse +import importlib.metadata import json import logging import os import sys from pathlib import Path -import torch - -from torchchat.cli.download import download_and_convert, is_model_downloaded - from torchchat.utils.build_utils import ( allowable_dtype_names, allowable_params_table, - get_device_str, ) logging.basicConfig(level=logging.INFO, format="%(message)s") logger = logging.getLogger(__name__) default_device = os.getenv("TORCHCHAT_DEVICE", "fast") +default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast") + default_model_dir = Path( os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache") ).expanduser() @@ -42,6 +40,9 @@ # Handle CLI arguments that are common to a majority of subcommands. def check_args(args, verb: str) -> None: + # Local import to avoid unnecessary expensive imports + from torchchat.cli.download import download_and_convert, is_model_downloaded + # Handle model download. Skip this for download, since it has slightly # different semantics. if ( @@ -150,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None: model_config_parser.add_argument( "--dtype", - default="fast", + default=None, choices=allowable_dtype_names(), - help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast", + help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast", ) model_config_parser.add_argument( "--quantize", @@ -166,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None: model_config_parser.add_argument( "--device", type=str, - default=default_device, + default=None, choices=["fast", "cpu", "cuda", "mps"], - help="Hardware device to use. Options: cpu, cuda, mps", + help="Hardware device to use. Options: fast, cpu, cuda, mps", ) @@ -498,9 +499,10 @@ def _add_speculative_execution_args(parser) -> None: def arg_init(args): - if not (torch.__version__ > "2.3"): + torch_version = importlib.metadata.version("torch") + if not torch_version or (torch_version <= "2.3"): raise RuntimeError( - f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release" + f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release" ) if sys.version_info.major != 3 or sys.version_info.minor < 10: @@ -513,17 +515,34 @@ def arg_init(args): if isinstance(args.quantize, str): args.quantize = json.loads(args.quantize) - # if we specify dtype in quantization recipe, replicate it as args.dtype - args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype) + # if we specify dtype in quantization recipe, allow args.dtype top override if specified + if args.dtype is None: + args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype) + else: + precision_handler = args.quantize.get("precision", None) + if precision_handler: + if precision_handler["dtype"] != args.dtype: + print('overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}') + precision_handler["dtype"] = args.dtype if getattr(args, "output_pte_path", None): - if args.device not in ["cpu", "fast"]: + if args.device not in [None, "cpu", "fast"]: raise RuntimeError("Device not supported by ExecuTorch") args.device = "cpu" else: - args.device = get_device_str( - args.quantize.get("executor", {}).get("accelerator", args.device) - ) + # Localized import to minimize expensive imports + from torchchat.utils.build_utils import get_device_str + + if args.device is None or args.device == "fast": + args.device = get_device_str( + args.quantize.get("executor", {}).get("accelerator", default_device) + ) + else: + executor_handler = args.quantize.get("executor", None) + if executor_handler: + if executor_handler["accelerator"] != args.device: + print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') + executor_handler["accelerator"] = args.device if "mps" in args.device: if getattr(args, "compile", False) or getattr(args, "compile_prefill", False): @@ -534,5 +553,8 @@ def arg_init(args): vars(args)["compile_prefill"] = False if hasattr(args, "seed") and args.seed: + # Localized import to minimize expensive imports + import torch + torch.manual_seed(args.seed) return args diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py index f95cbdaef..122ab0f28 100644 --- a/torchchat/cli/convert_hf_checkpoint.py +++ b/torchchat/cli/convert_hf_checkpoint.py @@ -11,25 +11,23 @@ from pathlib import Path from typing import Optional -import torch - -from torchchat.model import TransformerArgs - # support running without installing as a package wd = Path(__file__).parent.parent sys.path.append(str(wd.resolve())) sys.path.append(str((wd / "build").resolve())) -from torchchat.model import ModelArgs - -@torch.inference_mode() def convert_hf_checkpoint( *, model_dir: Optional[Path] = None, model_name: Optional[str] = None, remove_bin_files: bool = False, ) -> None: + + # Local imports to avoid expensive imports + from torchchat.model import ModelArgs, TransformerArgs + import torch + if model_dir is None: model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf") if model_name is None: @@ -41,27 +39,23 @@ def convert_hf_checkpoint( config = TransformerArgs.from_params(config_args) print(f"Model config {config.__dict__}") - # Load the json file containing weight mapping + # Find all candidate weight mapping index files model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))] - assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files" - if len(model_map_json_matches): - model_map_json = model_map_json_matches[0] - else: - model_map_json = model_dir / "pytorch_model.bin.index.json" # If there is no weight mapping, check for a consolidated model and # tokenizer we can move. Llama 2 and Mistral have weight mappings, while # Llama 3 has a consolidated model and tokenizer. # Otherwise raise an error. - if not model_map_json.is_file(): + if not model_map_json_matches: consolidated_pth = model_dir / "original" / "consolidated.00.pth" tokenizer_pth = model_dir / "original" / "tokenizer.model" if consolidated_pth.is_file() and tokenizer_pth.is_file(): # Confirm we can load it - loaded_result = torch.load( - str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True - ) - del loaded_result # No longer needed + with torch.inference_mode(): + loaded_result = torch.load( + str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True + ) + del loaded_result # No longer needed print(f"Moving checkpoint to {model_dir / 'model.pth'}.") os.rename(consolidated_pth, model_dir / "model.pth") os.rename(tokenizer_pth, model_dir / "tokenizer.model") @@ -69,11 +63,30 @@ def convert_hf_checkpoint( return else: raise RuntimeError( - f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}" + f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}" ) - with open(model_map_json) as json_map: - bin_index = json.load(json_map) + # Load the json file(s) containing weight mapping + # + # NOTE: If there are multiple index files, there are two possibilities: + # 1. The files could be mapped to different weight format files (e.g. .bin + # vs .safetensors) + # 2. The files could be split subsets of the mappings that need to be + # merged + # + # In either case, we can simply keep the mappings where the target file is + # valid in the model dir. + bin_index = {} + for weight_map_file in model_map_json_matches: + with open(weight_map_file, "r") as handle: + weight_map = json.load(handle) + valid_mappings = { + k: model_dir / v + for (k, v) in weight_map.get("weight_map", {}).items() + if (model_dir / v).is_file() + } + bin_index.update(valid_mappings) + bin_files = set(bin_index.values()) weight_map = { "model.embed_tokens.weight": "tok_embeddings.weight", @@ -97,7 +110,6 @@ def convert_hf_checkpoint( "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", } - bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()} def permute(w, n_heads): return ( @@ -130,7 +142,8 @@ def load_safetensors(): state_dict = None for loader in loaders: try: - state_dict = loader() + with torch.inference_mode(): + state_dict = loader() break except Exception: continue @@ -173,7 +186,6 @@ def load_safetensors(): os.remove(file) -@torch.inference_mode() def convert_hf_checkpoint_to_tune( *, model_dir: Optional[Path] = None, diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index f145c93fb..4da2bc390 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -35,11 +35,12 @@ def _download_hf_snapshot( model_info = model_info(model_config.distribution_path, token=hf_token) model_fnames = [f.rfilename for f in model_info.siblings] - # Check the model config for preference between safetensors and pth + # Check the model config for preference between safetensors and pth/bin has_pth = any(f.endswith(".pth") for f in model_fnames) + has_bin = any(f.endswith(".bin") for f in model_fnames) has_safetensors = any(f.endswith(".safetensors") for f in model_fnames) - # If told to prefer safetensors, ignore pth files + # If told to prefer safetensors, ignore pth/bin files if model_config.prefer_safetensors: if not has_safetensors: print( @@ -47,10 +48,10 @@ def _download_hf_snapshot( file=sys.stderr, ) exit(1) - ignore_patterns = "*.pth" + ignore_patterns = ["*.pth", "*.bin"] # If the model has both, prefer pth files over safetensors - elif has_pth and has_safetensors: + elif (has_pth or has_bin) and has_safetensors: ignore_patterns = "*safetensors*" # Otherwise, download everything @@ -110,6 +111,8 @@ def _download_direct( def download_and_convert( model: str, models_dir: Path, hf_token: Optional[str] = None ) -> None: + if model is None: + raise ValueError("'download' command needs a model name or alias.") model_config = resolve_model_config(model) model_dir = models_dir / model_config.name @@ -234,4 +237,8 @@ def where_main(args) -> None: # Subcommand to download model artifacts. def download_main(args) -> None: - download_and_convert(args.model, args.model_directory, args.hf_token) + try: + download_and_convert(args.model, args.model_directory, args.hf_token) + except ValueError as e: + print(e, file=sys.stderr) + sys.exit(1) diff --git a/torchchat/generate.py b/torchchat/generate.py index be6a2e819..66f26ff9f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -71,11 +71,11 @@ def encode_header(self, role) -> List[int]: def encode_message(self, message) -> List[int]: tokens = self.encode_header(message["role"]) - if type(message["content"]) is str: + if isinstance(message["content"], str): tokens.extend( self.tokenizer.encode(message["content"], bos=False, eos=False) ) - elif type(message["content"]) is list: + elif isinstance(message["content"], list): for content in message["content"]: if content["type"] == "text": tokens.extend( @@ -190,7 +190,7 @@ def from_args(cls, args): for image_prompt in image_prompts if (not os.path.exists(image_prompt)) ] - if len(non_existent_image_prompts): + if non_existent_image_prompts: raise RuntimeError( f"Image prompt {non_existent_image_prompts} does not exist" ) @@ -238,7 +238,7 @@ def __init__( draft_quantize: bool, ): torch._inductor.config.coordinate_descent_tuning = ( - False if builder_args.device == "cpu" else True + builder_args.device != "cpu" ) torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future @@ -591,6 +591,7 @@ def generate( Dict[str, Any] ] = None, # List of Image prompt tensors for multimodal models start_pos: int = 0, + skip_cache_setup: bool = False, draft_model: Model, speculate_k: Optional[int] = 8, sequential_prefill=True, @@ -614,26 +615,27 @@ def generate( max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length) # set up caches only if first inference if start_pos == 0: - model = model.to(device=device) - with torch.device(device): - if ( - self.is_torchtune_model - or self.model.config.model_type == ModelType.Flamingo - ): - # 6404 is one-gpu affordable max_seq_length for single image input - model.setup_caches( - batch_size=1, - dtype=self.dtype, - encoder_max_seq_len=6404, - decoder_max_seq_len=max_seq_length, - ) - else: - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - if is_speculative and draft_model is not model: - draft_model.setup_caches( - max_batch_size=1, - max_seq_length=max_seq_length, - ) + if not skip_cache_setup: + model = model.to(device=device) + with torch.device(device): + if ( + self.is_torchtune_model + or self.model.config.model_type == ModelType.Flamingo + ): + # 6404 is one-gpu affordable max_seq_length for single image input + model.setup_caches( + batch_size=1, + dtype=self.dtype, + encoder_max_seq_len=6404, + decoder_max_seq_len=max_seq_length, + ) + else: + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches( + max_batch_size=1, + max_seq_length=max_seq_length, + ) if model.config.model_type == ModelType.Flamingo: model.reset_caches() @@ -915,13 +917,6 @@ def chat( ] ) if generator_args.compile: - if ( - self.is_speculative and self.builder_args.use_distributed - ): # and ("cuda" in builder_args.device): - torch._inductor.config.triton.cudagraph_trees = ( - False # Bug with cudagraph trees in this case - ) - if self.builder_args.device == "cpu": if generator_args.max_autotune: kwargs = {"mode": "max-autotune"} @@ -1002,11 +997,8 @@ def chat( max_seq_length, ) - max_seq_length = ( - max_seq_length + self.speculative_builder_args.speculate_k + 1 - if self.draft_model is not None - else max_seq_length - ) + if self.draft_model is not None: + max_seq_length += self.speculative_builder_args.speculate_k + 1 aggregate_metrics = { "tokens_per_sec": [], @@ -1023,6 +1015,7 @@ def chat( ) for i in range(num_samples): device_sync(device=self.builder_args.device) + is_first_sample: bool = i == 0 if generator_args.chat_mode: prompt = input("User: ") if prompt == "/bye": @@ -1048,7 +1041,7 @@ def chat( ] ) self.system_prompt = None - elif i == 0: + elif is_first_sample: encoded = self.chat_formatter.encode_dialog_prompt( [{"role": "user", "content": prompt}] ) @@ -1094,9 +1087,7 @@ def callback(x, *, done_generating=False): torch._inductor.config.profiler_mark_wrapper_call = True torch._inductor.config.cpp.enable_kernel_profile = True - if (i != generator_args.num_samples - 1 or not self.profile) or ( - self.builder_args.use_distributed and self.rank != 0 - ): + if i != generator_args.num_samples - 1 or not self.profile: import contextlib prof = contextlib.nullcontext() @@ -1119,6 +1110,7 @@ def callback(x, *, done_generating=False): top_k=generator_args.top_k, sequential_prefill=generator_args.sequential_prefill, start_pos=start_pos, + skip_cache_setup=not is_first_sample, max_seq_length=max_seq_length, ) for token_tensor, metrics in generator_func: @@ -1128,7 +1120,7 @@ def callback(x, *, done_generating=False): if metrics is not None: aggregate_metrics.update(metrics) yield token_tensor, metrics - jit_compile = (i == 0) and ( + jit_compile = is_first_sample and ( generator_args.compile or generator_args.compile_prefill ) compilation_time = time.perf_counter() - t0 @@ -1139,10 +1131,7 @@ def callback(x, *, done_generating=False): print(prof.key_averages().table(sort_by="self_cpu_time_total")) else: print(prof.key_averages().table(sort_by="self_cuda_time_total")) - if self.builder_args.use_distributed: - prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json") - else: - prof.export_chrome_trace(f"{self.profile}.json") + prof.export_chrome_trace(f"{self.profile}.json") if start_pos >= max_seq_length: print( @@ -1200,12 +1189,27 @@ def callback(x, *, done_generating=False): f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}" ) - print( - f"\n Average tokens/sec (total): {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} \ - \nAverage tokens/sec (first token): {torch.mean(torch.tensor(aggregate_metrics['first_token_per_sec'])).item():.2f} \ - \nAverage tokens/sec (next tokens): {torch.mean(torch.tensor(aggregate_metrics['next_tokens_per_sec'])).item():.2f} \n\ + avg_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["tokens_per_sec"]) + ).item() + avg_first_token_sec = torch.mean( + torch.tensor(aggregate_metrics["first_token_per_sec"]) + ).item() + avg_next_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["next_tokens_per_sec"]) + ).item() + + if not ( + torch.isnan(torch.tensor(avg_tokens_sec)) + or torch.isnan(torch.tensor(avg_first_token_sec)) + or torch.isnan(torch.tensor(avg_next_tokens_sec)) + ): + print( + f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \ + \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \ + \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\ " - ) + ) if torch.cuda.is_available(): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") diff --git a/torchchat/model.py b/torchchat/model.py index 11f3dc167..2a3b9f12f 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -94,7 +94,7 @@ def __init__( self.encoder = encoder self.decoder = decoder - # esclate the embedding layer outside decoder llava model need to fuse + # escalate the embedding layer outside decoder llava model need to fuse # the text and image embedding together before passing to decoder. self.tok_embeddings = getattr(self.decoder, token_embedding_name) diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 72a6dfc9b..99fd82fe8 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -388,6 +388,8 @@ def callback(x, *, done_generating=False): device_sync(device=self.builder_args.device) + buffer = [] + ILLEGAL_CHAR = '\ufffd' # Process each token, metrics tuple yielded by Generator.generate. for y, _ in self.generate( model=self.model, @@ -413,10 +415,15 @@ def callback(x, *, done_generating=False): break y = y.view(-1) + buffer.append(y.item()) # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token. content = "".join( - self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:] + self.tokenizer.decode([self.tokenizer.encode(".")[0]] + buffer)[1:] ) + # Skip content while illegal characters appear. + if ILLEGAL_CHAR in content: + continue + buffer.clear() # Package the sequence into a CompletionChunkResponse and yield it. chunk_delta = ChunkDelta( diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index 005bb6ef2..2685ec2f3 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -13,18 +13,31 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple -import torch ########################################################################## ### unpack packed weights ### +class _LazyImportTorch: + """This is a wrapper around the import of torch that only performs the + import when an actual attribute is needed off of torch. + """ + @staticmethod + def __getattribute__(name: str) -> Any: + import torch + return getattr(torch, name) + + +# Alias torch to the lazy import +torch = _LazyImportTorch() + + def unpack_packed_weights( packed_weights: Dict[str, Any], packed_linear: Callable, - input_dtype: torch.dtype, + input_dtype: "torch.dtype", unpacked_dims: Tuple, -) -> torch.Tensor: +) -> "torch.Tensor": """Given a packed weight matrix `packed_weights`, a Callable implementing a packed linear function for the packed format, and the unpacked dimensions of the weights, recreate the unpacked weight @@ -169,26 +182,27 @@ def name_to_dtype(name, device): return torch.bfloat16 try: - return name_to_dtype_dict[name] + return _name_to_dtype_dict[name]() except KeyError: raise RuntimeError(f"unsupported dtype name {name} specified") def allowable_dtype_names() -> List[str]: - return name_to_dtype_dict.keys() - - -name_to_dtype_dict = { - "fp32": torch.float, - "fp16": torch.float16, - "bf16": torch.bfloat16, - "float": torch.float, - "half": torch.float16, - "float32": torch.float, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "fast": None, - "fast16": None, + return _name_to_dtype_dict.keys() + + +# NOTE: values are wrapped in lambdas to avoid proactive imports for torch +_name_to_dtype_dict = { + "fp32": lambda: torch.float, + "fp16": lambda: torch.float16, + "bf16": lambda: torch.bfloat16, + "float": lambda: torch.float, + "half": lambda: torch.float16, + "float32": lambda: torch.float, + "float16": lambda: torch.float16, + "bfloat16": lambda: torch.bfloat16, + "fast": lambda: None, + "fast16": lambda: None, }