Running the on-chip test with a 124928-token prefill fails with a CUDA out-of-memory error while the draft model's CUDA graphs are being captured:
```bash
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
    --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
```
```
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00, 3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
```
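For context, here is a back-of-the-envelope estimate of what the full on-chip KV cache needs at this prefill length. This is a sketch assuming the standard Llama-2-7B shape (32 layers, 32 KV heads, head dimension 128) and an fp16 cache; the numbers are not read from the TriForce config:

```python
# Rough fp16 KV-cache estimate for the target model at this prefill length.
# Assumes the standard Llama-2-7B shape (32 layers, 32 KV heads, head_dim 128);
# illustrative numbers, not taken from the TriForce code.
layers, kv_heads, head_dim = 32, 32, 128
dtype_bytes = 2                       # fp16
prefill = 124928

bytes_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes  # K and V
kv_cache_gib = prefill * bytes_per_token / 2**30
print(f"{kv_cache_gib:.1f} GiB")      # ~61.0 GiB for the full on-chip cache
```

Together with roughly 13 GB of fp16 weights for the 7B target, this leaves very little headroom on an 80 GB card and none on anything smaller, which would explain why even a 32 MiB allocation fails during graph capture below.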
```
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
```
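The failure occurs while replay graphs for the draft model are being captured, one per gamma offset. For reference, the generic PyTorch capture pattern looks roughly like this; a minimal sketch of the `torch.cuda.graph` API, not the actual code in `utils/graph_infer.py`:

```python
import torch

# Minimal sketch of CUDA graph capture with static buffers
# (generic PyTorch API; not the actual draft_run_capture_graph code).
model = torch.nn.Linear(1024, 1024).cuda().half()
static_input = torch.zeros(1, 1024, device="cuda", dtype=torch.half)

# Warm-up on a side stream so lazy allocations happen before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):                 # allocations made during capture
    static_output = model(static_input)   # come from the graph's private pool

static_input.copy_(torch.randn_like(static_input))
g.replay()                                # re-runs the captured kernels
```

By default each captured graph reserves its own private memory pool, so capturing several graphs in a row on a device that is already close to full can push it over the limit.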
```
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
    self.callables[gamma_offset] = draft_run_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
    static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
    logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU
```
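The last line of the OOM message is cut off above, so the exact free/total figures are unknown. If the device is simply near capacity, two standard PyTorch knobs may be worth trying before reducing `--prefill` or `--budget`. This is a generic sketch, not TriForce-specific, and whether it helps depends on how much the full on-chip cache actually needs:

```python
import os

# Let the caching allocator grow segments instead of over-reserving;
# must be set before CUDA is first initialized (recent PyTorch versions).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Check how much memory is actually free on the visible device.
free, total = torch.cuda.mem_get_info()
print(f"{free / 2**30:.1f} GiB free of {total / 2**30:.1f} GiB")
```

If the estimate above is right, though, a 124928-token on-chip cache needs on the order of 61 GiB by itself, so a smaller `--prefill` may be the only real fix on GPUs below 80 GB.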