Skip to content

Commit c0125d7

Browse files
authored
Add files via upload
1 parent 7eca7d0 commit c0125d7

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

generate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,8 @@ def main(
314314
torch.manual_seed(1234)
315315
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
316316
if compile:
317-
# MKG
318-
# if is_speculative and use_tp:
319-
# torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
317+
if is_speculative and use_tp: # and ("cuda" in device):
318+
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
320319

321320
if is_speculative:
322321
global model_forward, logits_to_prob

0 commit comments

Comments
 (0)