Commit da57fe3

[Llama] Use rocm ukernel when available + use num_layer for pkv. (#381)
Use a ukernel to improve performance and fix #380. Additionally, this adds a fix to stateless llama to handle models whose layer count is not 32. Currently, our PKV size is based on the number of attention heads. This works only because the number of attention heads happens to equal the number of layers for many of the models we are looking at; once that assumption breaks, stateless llama will run into issues. This PR also fixes that minor bug.
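
A minimal sketch of the coincidence described above, assuming Llama-2-7B-style config values (32 layers, 32 heads; these numbers are illustrative, not taken from this PR):

    # The KV cache holds one (key, value) pair per decoder *layer*.
    num_hidden_layers = 32
    # Equal to num_hidden_layers only by coincidence for these models.
    num_attention_heads = 32
    # The old code sized the cache by heads; it only worked because
    # num_attention_heads * 2 == num_hidden_layers * 2 happened to hold.
    assert num_attention_heads * 2 == num_hidden_layers * 2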
1 parent c1dc94c commit da57fe3

File tree: 2 files changed (+41 -21 lines)

python/turbine_models/custom_models/stateless_llama.py

Lines changed: 30 additions & 21 deletions
@@ -1,6 +1,7 @@
 import os
 import sys
 import re
+import json

 os.environ["TORCH_LOGS"] = "dynamic"
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -61,19 +62,26 @@
     help="Compile LLM with StreamingLLM optimizations",
 )

-# TODO (Dan): replace this with a file once I figure out paths on windows exe
-json_schema_64 = """
-[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
-"""

-json_schema_16 = """
-[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
-"""
+def generate_schema(num_layers):
+    null = None
+    schema = [1, {"type": "builtins.tuple", "context": "null", "children_spec": []}]
+    kv_schema_per_layer = {
+        "type": "builtins.tuple",
+        "context": "null",
+        "children_spec": [
+            {"type": null, "context": null, "children_spec": []},
+            {"type": null, "context": null, "children_spec": []},
+        ],
+    }
+    for i in range(num_layers):
+        schema[1]["children_spec"].append(kv_schema_per_layer)
+    return json.dumps(schema)


-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
+def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
-    for i in range(heads * 2):
+    for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
         sliced = IREE.tensor_slice(
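
For illustration, here is what the new generate_schema helper returns for a hypothetical single-layer model; the output follows directly from the function body above and matches the structure of the deleted hard-coded strings, with one (key, value) tuple spec per layer:

    >>> generate_schema(1)
    '[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]'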
@@ -105,10 +113,8 @@ def export_transformer_model(
         torch_dtype=torch.float,
         token=hf_auth_token,
     )
-    if mod.config.num_attention_heads == 8:
-        state_schema = pytree.treespec_loads(json_schema_16)
-    else:
-        state_schema = pytree.treespec_loads(json_schema_64)
+    schema_json = generate_schema(mod.config.num_hidden_layers)
+    state_schema = pytree.treespec_loads(schema_json)
     if streaming_llm:
         enable_llama_pos_shift_attention(mod)
     dtype = torch.float32
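
A minimal sketch of the replacement path, using the same treespec_loads call the file already makes (the import path and the 4-layer count are assumptions for illustration):

    from torch.utils import _pytree as pytree  # assumed import of the pytree helpers

    schema_json = generate_schema(4)  # hypothetical 4-layer model
    state_schema = pytree.treespec_loads(schema_json)
    # state_schema now describes a tuple of 4 (key, value) tuples, i.e. the
    # past_key_values layout for any layer count, not just the 8- and
    # 32-layer cases the hard-coded strings covered.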
@@ -121,12 +127,13 @@ def export_transformer_model(
         token=hf_auth_token,
     )
     # TODO: generate these values instead of magic numbers
+    NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = mod.config.num_attention_heads
     HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )

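As a worked example of the new allocation, with assumed Llama-2-7B-style values (NUM_LAYERS = 32, HEADS = 32, HIDDEN_DIM = 4096 / 32 = 128, MAX_STEP_SEQ = 4096 - 1; not values from this PR):

    # (NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
    # = (64, 1, 4095, 32, 128)
    # The old HEADS * 2 leading dim gave the same 64 only because
    # HEADS == NUM_LAYERS here; for a model where they differ, the old
    # shape would be wrong.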
@@ -161,7 +168,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
@@ -172,7 +179,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):

         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +190,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update = IREE.tensor_reshape(
                     state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
                 )
@@ -226,7 +233,7 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +250,7 @@ def run_cached_initialize(
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 slice_of_state = IREE.tensor_reshape(
                     state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
@@ -278,7 +285,7 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(HEADS * 2):
+            for i in range(NUM_LAYERS * 2):
                 update_window_state = IREE.tensor_slice(
                     self.global_state,
                     i,
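
The window arithmetic in this hunk, traced with a hypothetical current sequence length:

    # sink_size = 4, window_size = 252
    # With global_seq_step = 300:
    # most_recent_window = 300 + (-252) = 48,
    # so the slice below presumably reads the most recent 252 cached steps,
    # now for each of the NUM_LAYERS * 2 cached tensors rather than HEADS * 2.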
@@ -339,12 +346,14 @@ def evict_kvcache_space(self):
             [
                 "--iree-rocm-target-chip=" + target_triple,
                 "--iree-rocm-link-bc=true",
-                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                 "--iree-vm-bytecode-module-strip-source-map=true",
                 "--iree-opt-strip-assertions=true",
                 "--iree-vm-target-truncate-unsupported-floats",
             ]
         )
+        ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
+        if target_triple in ukernel_supported_arch:
+            flags.extend(["--iree-rocm-enable-ukernels=argmax"])
     elif device == "cuda":
         flags.extend(
             [
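
A standalone sketch of the gating added in this hunk (the architecture set and flag come from the diff itself; the target_triple value is a hypothetical example):

    flags = []
    target_triple = "gfx90a"  # e.g. an MI200-series GPU; illustrative only
    ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
    if target_triple in ukernel_supported_arch:
        # Lower argmax through the ROCm ukernel on supported targets.
        flags.extend(["--iree-rocm-enable-ukernels=argmax"])
    # On other architectures the flag is simply omitted and compilation
    # proceeds with the default lowering.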
