Out of memory on H800 #7

@Lucas-TY

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
    self.callables[gamma_offset] = draft_run_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
    static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
    logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 
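
The failing request is only 32.00 MiB, which suggests the device was already almost full before the draft graph capture reached this point. A rough back-of-the-envelope estimate of the target model's footprint at this prefill length is sketched below. It assumes the standard Llama-2-7B shape (32 layers, hidden size 4096, full multi-head attention), fp16 storage, roughly 6.74B parameters, and an 80 GB card; none of these are stated in the log. It ignores the draft model, the retrieval and draft caches, activations, and CUDA graph memory pools, so actual usage would be higher.

```python
# Rough memory estimate for the target model at the reported prefill length.
# All model-shape numbers are assumptions (standard Llama-2-7B, fp16),
# not values read from the log above.

GIB = 1024 ** 3

def kv_cache_bytes(num_tokens, num_layers=32, hidden_size=4096, bytes_per_elem=2):
    """Bytes for a full KV cache: one K and one V vector per token per layer."""
    return 2 * num_layers * hidden_size * bytes_per_elem * num_tokens

prefill = 124928                      # from --prefill in the command
approx_params = 6.74e9                # assumed parameter count for a 7B Llama-2
weights_bytes = approx_params * 2     # fp16: 2 bytes per parameter

kv_gib = kv_cache_bytes(prefill) / GIB
weights_gib = weights_bytes / GIB

print(f"KV cache : {kv_gib:.1f} GiB")
print(f"Weights  : {weights_gib:.1f} GiB")
print(f"Subtotal : {kv_gib + weights_gib:.1f} GiB (before caches for drafting, graphs, activations)")
```

Under these assumptions the full KV cache alone comes to 61 GiB and the weights add about 12.5 GiB, leaving only a few GiB of headroom on an 80 GB device. If that estimate is in the right range, a shorter `--prefill` would be the first thing to try.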
