We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7eca7d0 commit c0125d7Copy full SHA for c0125d7
generate.py
@@ -314,9 +314,8 @@ def main(
314
torch.manual_seed(1234)
315
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
316
if compile:
317
- # MKG
318
- # if is_speculative and use_tp:
319
- # torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
+ if is_speculative and use_tp: # and ("cuda" in device):
+ torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
320
321
if is_speculative:
322
global model_forward, logits_to_prob
0 commit comments