Skip to content

Commit 29e864b

Browse files
committed
Default device to CPU if CUDA not available in some arguments
1 parent 48328fb commit 29e864b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def main(
259259
profile: Optional[Path] = None,
260260
draft_checkpoint_path: Optional[Path] = None,
261261
speculate_k: int = 5,
262-
device='cuda',
262+
device=('cuda' if torch.cuda.is_available() else 'cpu'),
263263
) -> None:
264264
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
265265
"""
@@ -414,7 +414,7 @@ def callback(x):
414414
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
415415
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
416416
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
417-
parser.add_argument('--device', type=str, default="cuda", help='Device to use')
417+
parser.add_argument('--device', type=str, default=('cuda' if torch.cuda.is_available() else 'cpu'), help='Device to use')
418418

419419
args = parser.parse_args()
420420
main(

quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def quantize(
619619
parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
620620
parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
621621
parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
622-
parser.add_argument('--device', type=str, default='cuda', help='device to use')
622+
parser.add_argument('--device', type=str, default=('cuda' if torch.cuda.is_available() else 'cpu'), help='device to use')
623623

624624
args = parser.parse_args()
625625
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)

0 commit comments

Comments
 (0)