Skip to content

Commit d4dbc60

Browse files
committed
2 parents 52625f8 + f479b07 commit d4dbc60

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

generate.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
def device_sync(device):
1717
if "cuda" in device:
1818
torch.cuda.synchronize(device)
19-
elif "cpu" in device:
19+
elif ("cpu" in device) or ("mps" in device):
2020
pass
2121
else:
2222
print(f"device={device} is not yet suppported")
@@ -26,6 +26,7 @@ def device_sync(device):
2626
torch._inductor.config.triton.unique_kernel_names = True
2727
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
2828

29+
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
2930

3031
# support running without installing as a package
3132
wd = Path(__file__).parent.parent.resolve()
@@ -206,7 +207,7 @@ def generate(
206207
}
207208
return seq, generate_stats
208209

209-
def encode_tokens(tokenizer, string, bos=True, device='cuda'):
210+
def encode_tokens(tokenizer, string, bos=True, device=default_device):
210211
tokens = tokenizer.encode(string)
211212
if bos:
212213
tokens = [tokenizer.bos_id()] + tokens
@@ -259,7 +260,7 @@ def main(
259260
profile: Optional[Path] = None,
260261
draft_checkpoint_path: Optional[Path] = None,
261262
speculate_k: int = 5,
262-
device='cuda',
263+
device=default_device,
263264
) -> None:
264265
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
265266
"""
@@ -414,7 +415,7 @@ def callback(x):
414415
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
415416
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
416417
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
417-
parser.add_argument('--device', type=str, default="cuda", help='Device to use')
418+
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
418419

419420
args = parser.parse_args()
420421
main(

mixtral-moe/scripts/download.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
1313
from huggingface_hub import snapshot_download
1414
os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
1515
try:
16-
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
16+
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
1717
except HTTPError as e:
1818
if e.response.status_code == 401:
1919
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")

quantize.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919

2020
from model import Transformer
2121

22+
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
23+
2224
##### Quantization Primitives ######
2325

2426
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
@@ -539,7 +541,7 @@ def quantize(
539541
percdamp: float = .01,
540542
blocksize: int = 128,
541543
label: str = '',
542-
device: str = 'cuda',
544+
device: str = default_device,
543545
) -> None:
544546
assert checkpoint_path.is_file(), checkpoint_path
545547

@@ -619,7 +621,7 @@ def quantize(
619621
parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
620622
parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
621623
parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
622-
parser.add_argument('--device', type=str, default='cuda', help='device to use')
624+
parser.add_argument('--device', type=str, default=default_device, help='device to use')
623625

624626
args = parser.parse_args()
625627
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)

tp.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,11 @@
99
import torch
1010
import torch.distributed as dist
1111
from torch import nn
12-
from torch.distributed import _functional_collectives as funcol
12+
if os.uname().sysname != "Darwin":
13+
from torch.distributed import _functional_collectives as funcol
14+
else:
15+
# Distributed is not supported on MacOS
16+
funcol = None
1317

1418
from model import Attention, FeedForward, Transformer
1519
from quantize import WeightOnlyInt4Linear

0 commit comments

Comments (0)