Changes from all commits (19 commits)
4a7dab8  [AOTI] Remove the original model weights in Python deployment (#1337) (desertfire, Nov 6, 2024)
ac02ffb  Minor code cleanups in generate.py and model.py (#1348) (swolchok, Nov 7, 2024)
170581a  feat(fast cli): Import torch lazily in all places used by the CLI tha… (gabe-l-hart, Nov 7, 2024)
743e6f3  Fix error: characters can not be displayed normally in chinese (#1342) (wjunLu, Nov 7, 2024)
e30aaa0  Update contributor channel name (#1354) (Gasoonjia, Nov 7, 2024)
2fcc37c  fix: enforce python version install requirements (#1345) (leseb, Nov 12, 2024)
0f58543  Remove last references to use_distributed argument (#1353) (mreso, Nov 13, 2024)
fe257fd  Add cstdint to tokenizer (missing include) (#1339) (byjlw, Nov 13, 2024)
0b385d3  Setup a SIGINT handler to gracefully exit the program once the user p… (leseb, Nov 13, 2024)
93f713f  Update cli.py to make --device/--dtype pre-empt quantize dict-specifi… (mikekgfb, Nov 13, 2024)
6eae887  Update Caching logic to only trigger on the first inference sample (#… (Jack-Khuu, Nov 13, 2024)
2cf1a17  Minor typo + Update install_requirements.sh to support python 3.10 >=… (Jack-Khuu, Nov 13, 2024)
ed0fb30  fix: Remove dup gguf dependency (#1371) (leseb, Nov 14, 2024)
4697764  Bug Fix: Check for explicit cli device (fast) (#1374) (Jack-Khuu, Nov 14, 2024)
d7b681a  fix: do not print perf stat when NaN (#1375) (leseb, Nov 15, 2024)
5da240a  fix: Fail gracefully when "model" arg is missing when downloading (#1… (leseb, Nov 16, 2024)
aee6487  fix: allow multiple weight mapping files for mistral (leseb, Nov 6, 2024)
295ae2a  fix(download): Fix safetensors/bin/pth download logic (gabe-l-hart, Nov 6, 2024)
5747a71  fix(convert hf): Better logic to handle multiple weight mapping files (gabe-l-hart, Nov 6, 2024)
README.md (2 changes: 1 addition & 1 deletion)
@@ -575,7 +575,7 @@ We really value our community and the contributions made by our wonderful users.

To connect with us and other community members, we invite you to join our Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
* Head to the `#torchchat-general` channel for general questions, discussion, and community support.
* Join the `#torchchat-contribution` channel if you're interested in contributing directly to project development.
* Join the `#torchchat-contributors` channel if you're interested in contributing directly to project development.

Looking forward to discussing with you about torchchat future!

install/install_requirements.sh (31 changes: 23 additions & 8 deletions)
@@ -9,26 +9,41 @@ set -eou pipefail

# Install required python dependencies for developing
# Dependencies are defined in .pyproject.toml
PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-python}
if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
if [ -z "${PYTHON_EXECUTABLE:-}" ];
then
PYTHON_EXECUTABLE=python3
if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
then
PYTHON_EXECUTABLE=python3
else
PYTHON_EXECUTABLE=python
fi
fi

# Check python version. Expect 3.10.x or 3.11.x
printf "import sys\nif sys.version_info.major != 3 or sys.version_info.minor < 10 :\n\tprint('Please use Python >=3.10');sys.exit(1)\n" | $PYTHON_EXECUTABLE
if [[ $? -ne 0 ]]
echo "Using python executable: $PYTHON_EXECUTABLE"

PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")"
# Check python version. Expect at least 3.10.x
if ! $PYTHON_EXECUTABLE -c "
import sys
if sys.version_info < (3, 10):
sys.exit(1)
";
then
echo "Python version must be at least 3.10.x. Detected version: $PYTHON_SYS_VERSION"
exit 1
fi

if [[ "$PYTHON_EXECUTABLE" == "python" ]];
then
PIP_EXECUTABLE=pip
else
elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
then
PIP_EXECUTABLE=pip3
else
PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
fi

echo "Using pip executable: $PIP_EXECUTABLE"

#
# First install requirements in install/requirements.txt. Older torch may be
# installed from the dependency of other models. It will be overridden by
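For reference, a minimal Python sketch of what the new interpreter checks amount to (illustrative only; the authoritative logic is the shell script above):

```python
# Illustrative sketch of the checks install_requirements.sh now performs.
import sys

# The script aborts when the chosen interpreter is older than 3.10.
if sys.version_info < (3, 10):
    raise SystemExit("Python version must be at least 3.10.x")

# PYTHON_SYS_VERSION in the script, e.g. "3.12".
python_sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"

# Fallback pip name used when PYTHON_EXECUTABLE is neither "python" nor "python3";
# an explicit interpreter such as python3.12 maps to pip3.12.
pip_executable = f"pip{python_sys_version}"
print(f"Using pip executable: {pip_executable}")
```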
install/requirements.txt (1 change: 0 additions & 1 deletion)
@@ -14,7 +14,6 @@ snakeviz
sentencepiece
# numpy version range required by GGUF util
numpy >= 1.17, < 2.0
gguf
blobfile
tomli >= 1.1.0 ; python_version < "3.11"
openai
tokenizer/base64.h (1 change: 1 addition & 0 deletions)
@@ -25,6 +25,7 @@
#pragma once

#include <cassert>
#include <cstdint>
#include <string>
#include <string_view>

torchchat.py (10 changes: 9 additions & 1 deletion)
@@ -6,7 +6,7 @@

import argparse
import logging
import subprocess
import signal
import sys

# MPS ops missing with Multimodal torchtune
@@ -25,7 +25,15 @@
default_device = "cpu"


def signal_handler(sig, frame):
print("\nInterrupted by user. Bye!\n")
sys.exit(0)


if __name__ == "__main__":
# Set the signal handler for SIGINT
signal.signal(signal.SIGINT, signal_handler)

# Initialize the top-level parser
parser = argparse.ArgumentParser(
prog="torchchat",
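The handler above can be exercised outside torchchat with a few lines; this is a self-contained sketch (the `os.kill` call only simulates the user pressing Ctrl-C on a POSIX system):

```python
# Minimal reproduction of the SIGINT handling added to torchchat.py.
import os
import signal
import sys
import time

def signal_handler(sig, frame):
    print("\nInterrupted by user. Bye!\n")
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)
os.kill(os.getpid(), signal.SIGINT)  # simulate Ctrl-C
time.sleep(1)  # the handler runs here at the latest and exits with status 0
```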
torchchat/cli/builder.py (94 changes: 21 additions & 73 deletions)
@@ -16,12 +16,6 @@
import torch._inductor.config
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port

from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torchchat.model import Model, ModelArgs, ModelType

from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,20 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
return model


def _maybe_init_distributed(
builder_args: BuilderArgs,
) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
"""
Initialize distributed related setups if the user specified
using distributed inference. If not, this is a no-op.

Args:
builder_args (:class:`BuilderArgs`):
Command args for model building.
Returns:
Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
- The first element is an optional DeviceMesh object,
which which describes the mesh topology of devices for the DTensor.
- The second element is an optional ParallelDims object,
which represents the parallel dimensions configuration.
"""
if not builder_args.use_distributed:
return None, None
dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line

world_mesh, parallel_dims = launch_distributed(dist_config)

assert (
world_mesh is not None and parallel_dims is not None
), f"failed to launch distributed using {dist_config}"

return world_mesh, parallel_dims


def _maybe_parallelize_model(
model: nn.Module,
builder_args: BuilderArgs,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
) -> nn.Module:
"""
We parallelize the module and load the distributed checkpoint to the model
if the user specifies using distributed inference. If not, this is a no-op.

Args:
model (:class:`nn.Module`):
Module to be parallelized.
builder_args (:class:`BuilderArgs`):
Command args for model building.
world_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
parallel_dims (:class:`ParallelDims`):
Object which represents the parallel dimensions configuration.
Returns:
A :class:`nn.Module` object which is parallelized and checkpoint loaded
if the user specifies using distributed inference.
"""
if world_mesh is None:
return model
assert parallel_dims is not None
print("Applying model parallel to model ...")
parallelize_llama(model, world_mesh, parallel_dims)
return load_checkpoints_to_model(model, builder_args, world_mesh)


def _load_model(builder_args: BuilderArgs) -> Model:
# world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
# elif builder_args.use_distributed:
# model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

if builder_args.dso_path or builder_args.aoti_package_path:
# AOTI-compiled model will load its own weights.
# Release weights here to avoid OOM
import gc
if hasattr(model, "model"):
model.model = None
gc.collect()
torch.cuda.empty_cache()

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
@@ -584,6 +521,12 @@ def _initialize_model(
# attributes will NOT be seen by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
# Using cpp runner to run AOTI compiled model is recommended.

def do_nothing(max_batch_size, max_seq_length):
pass
model.setup_caches = do_nothing

model.forward = torch._export.aot_load(
str(builder_args.dso_path.absolute()), builder_args.device
)
@@ -617,6 +560,11 @@ def _initialize_model(
aoti_compiled_model = load_package(
str(builder_args.aoti_package_path.absolute())
)

def do_nothing(max_batch_size, max_seq_length):
pass
model.setup_caches = do_nothing

model.forward = aoti_compiled_model
metadata = aoti_compiled_model.get_metadata()
builder_args.device = metadata["AOTI_DEVICE_KEY"]
@@ -686,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
return "TikToken"
if tokenizers:
return "Tokenizers"
return "SentencePiece"
return "SentencePiece"
torchchat/cli/cli.py (56 changes: 39 additions & 17 deletions)
@@ -5,26 +5,24 @@
# LICENSE file in the root directory of this source tree.

import argparse
import importlib.metadata
import json
import logging
import os
import sys
from pathlib import Path

import torch

from torchchat.cli.download import download_and_convert, is_model_downloaded

from torchchat.utils.build_utils import (
allowable_dtype_names,
allowable_params_table,
get_device_str,
)

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")

default_model_dir = Path(
os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
).expanduser()
@@ -42,6 +40,9 @@

# Handle CLI arguments that are common to a majority of subcommands.
def check_args(args, verb: str) -> None:
# Local import to avoid unnecessary expensive imports
from torchchat.cli.download import download_and_convert, is_model_downloaded

# Handle model download. Skip this for download, since it has slightly
# different semantics.
if (
@@ -150,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:

model_config_parser.add_argument(
"--dtype",
default="fast",
default=None,
choices=allowable_dtype_names(),
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
)
model_config_parser.add_argument(
"--quantize",
@@ -166,9 +167,9 @@ model_config_parser.add_argument
model_config_parser.add_argument(
"--device",
type=str,
default=default_device,
default=None,
choices=["fast", "cpu", "cuda", "mps"],
help="Hardware device to use. Options: cpu, cuda, mps",
help="Hardware device to use. Options: fast, cpu, cuda, mps",
)


@@ -498,9 +499,10 @@ def _add_speculative_execution_args(parser) -> None:


def arg_init(args):
if not (torch.__version__ > "2.3"):
torch_version = importlib.metadata.version("torch")
if not torch_version or (torch_version <= "2.3"):
raise RuntimeError(
f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
)

if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -513,17 +515,34 @@ def arg_init(args):
if isinstance(args.quantize, str):
args.quantize = json.loads(args.quantize)

# if we specify dtype in quantization recipe, replicate it as args.dtype
args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
# if we specify dtype in the quantization recipe, allow args.dtype to override it if specified on the CLI
if args.dtype is None:
args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
else:
precision_handler = args.quantize.get("precision", None)
if precision_handler:
if precision_handler["dtype"] != args.dtype:
print(f'overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
precision_handler["dtype"] = args.dtype

if getattr(args, "output_pte_path", None):
if args.device not in ["cpu", "fast"]:
if args.device not in [None, "cpu", "fast"]:
raise RuntimeError("Device not supported by ExecuTorch")
args.device = "cpu"
else:
args.device = get_device_str(
args.quantize.get("executor", {}).get("accelerator", args.device)
)
# Localized import to minimize expensive imports
from torchchat.utils.build_utils import get_device_str

if args.device is None or args.device == "fast":
args.device = get_device_str(
args.quantize.get("executor", {}).get("accelerator", default_device)
)
else:
executor_handler = args.quantize.get("executor", None)
if executor_handler:
if executor_handler["accelerator"] != args.device:
print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
executor_handler["accelerator"] = args.device

if "mps" in args.device:
if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
@@ -534,5 +553,8 @@
vars(args)["compile_prefill"] = False

if hasattr(args, "seed") and args.seed:
# Localized import to minimize expensive imports
import torch

torch.manual_seed(args.seed)
return args
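The precedence introduced above can be summarized with a small sketch (hypothetical helper, not torchchat code): an explicit CLI `--dtype`/`--device` pre-empts the value in the quantize JSON, which in turn pre-empts the `TORCHCHAT_PRECISION`/`TORCHCHAT_DEVICE` environment defaults.

```python
# Illustrative sketch of the CLI-over-JSON precedence implemented in arg_init.
import os

def resolve_dtype(cli_dtype, quantize_cfg):
    """Return the dtype to use: CLI flag first, then quantize recipe, then env default."""
    default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
    json_dtype = quantize_cfg.get("precision", {}).get("dtype")
    if cli_dtype is not None:
        return cli_dtype  # explicit CLI value pre-empts the recipe
    return json_dtype if json_dtype is not None else default_dtype

# CLI wins over the recipe.
assert resolve_dtype("bf16", {"precision": {"dtype": "fp16"}}) == "bf16"
# Recipe is used when no CLI value is given.
assert resolve_dtype(None, {"precision": {"dtype": "fp16"}}) == "fp16"
```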