 import os
 import sys
 import re
+import json
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from transformers import AutoTokenizer, AutoModelForCausalLM
     help="Compile LLM with StreamingLLM optimizations",
 )
 
-# TODO (Dan): replace this with a file once I figure out paths on windows exe
-json_schema_64 = """
-[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
-"""
 
-json_schema_16 = """
-[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}]
-"""
+def generate_schema(num_layers):
+    null = None
+    schema = [1, {"type": "builtins.tuple", "context": "null", "children_spec": []}]
+    kv_schema_per_layer = {
+        "type": "builtins.tuple",
+        "context": "null",
+        "children_spec": [
+            {"type": null, "context": null, "children_spec": []},
+            {"type": null, "context": null, "children_spec": []},
+        ],
+    }
+    for i in range(num_layers):
+        schema[1]["children_spec"].append(kv_schema_per_layer)
+    return json.dumps(schema)
 
 
-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
+def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
-    for i in range(heads * 2):
+    for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
         sliced = IREE.tensor_slice(
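For reference, a quick sanity check of what generate_schema builds (illustrative only, not part of the patch): the serialized spec is a two-element list, a protocol marker plus a tuple spec holding one (k, v) tuple entry per layer, matching the shape of the old hardcoded strings.

    parsed = json.loads(generate_schema(2))        # hypothetical 2-layer model
    assert parsed[0] == 1                          # pytree serialization protocol marker
    assert len(parsed[1]["children_spec"]) == 2    # one entry per layer
    assert all(len(e["children_spec"]) == 2        # each entry is a (k, v) pair
               for e in parsed[1]["children_spec"])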
@@ -105,10 +113,8 @@ def export_transformer_model(
         torch_dtype=torch.float,
         token=hf_auth_token,
     )
-    if mod.config.num_attention_heads == 8:
-        state_schema = pytree.treespec_loads(json_schema_16)
-    else:
-        state_schema = pytree.treespec_loads(json_schema_64)
+    schema_json = generate_schema(mod.config.num_hidden_layers)
+    state_schema = pytree.treespec_loads(schema_json)
     if streaming_llm:
         enable_llama_pos_shift_attention(mod)
     dtype = torch.float32
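The loaded treespec is what lets the export path flatten Hugging Face's nested past_key_values into a flat list of kv tensors and rebuild it later. A minimal sketch of that round trip, assuming pytree here is torch.utils._pytree:

    import torch
    from torch.utils import _pytree as pytree

    # Hypothetical 2-layer past_key_values: a tuple of per-layer (k, v) tuples.
    pkv = tuple((torch.zeros(1), torch.zeros(1)) for _ in range(2))
    flat, spec = pytree.tree_flatten(pkv)          # flat == [k0, v0, k1, v1]
    rebuilt = pytree.tree_unflatten(flat, spec)    # nested structure restored
    # spec describes the same tree that generate_schema(2) serializes by hand.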
@@ -121,12 +127,13 @@ def export_transformer_model(
         token=hf_auth_token,
     )
     # TODO: generate these values instead of magic numbers
+    NUM_LAYERS = mod.config.num_hidden_layers
     HEADS = mod.config.num_attention_heads
     HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
 
@@ -161,7 +168,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
         self.global_seq_step = IREE.tensor_dim(
             state[0], 1
         )  # seq-length dimension, read from the (arbitrarily chosen) 0th kv tensor
-        for i in range(HEADS * 2):
+        for i in range(NUM_LAYERS * 2):
             slice_of_state = IREE.tensor_reshape(
                 state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
             )
@@ -172,7 +179,7 @@ def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
 
     def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
         state_arg = slice_up_to_step(
-            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
         )
         forw_const = (
             [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -183,7 +190,7 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
         )
         token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-        for i in range(HEADS * 2):
+        for i in range(NUM_LAYERS * 2):
             update = IREE.tensor_reshape(
                 state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
             )
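In numpy terms, each decode step writes one new token of kv into every slot; a toy sketch of the intended semantics with made-up shapes (illustrative only; the actual write is an IREE tensor update just past this hunk):

    import numpy as np

    NUM_LAYERS, HEADS, HIDDEN_DIM, MAX_STEP_SEQ = 2, 4, 8, 16
    global_pkv = np.zeros((NUM_LAYERS * 2, 1, MAX_STEP_SEQ, HEADS, HIDDEN_DIM))
    state_update = [np.ones((1, 1, HEADS * HIDDEN_DIM)) for _ in range(NUM_LAYERS * 2)]
    global_seq_step = 5
    for i in range(NUM_LAYERS * 2):
        # each update carries exactly one new token of kv for slot i
        global_pkv[i, 0, global_seq_step, :, :] = state_update[i].reshape(HEADS, HIDDEN_DIM)
    global_seq_step += 1  # presumably advanced once per decoded token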
@@ -226,7 +233,7 @@ def run_cached_initialize(
         self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
     ):
         state_arg = slice_up_to_step(
-            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
+            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
         )
         forw_const = (
             [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -243,7 +250,7 @@ def run_cached_initialize(
         len_of_new_tokens = IREE.tensor_dim(
             state[0], 1
         )  # seq-length dimension, read from the (arbitrarily chosen) 0th kv tensor
-        for i in range(HEADS * 2):
+        for i in range(NUM_LAYERS * 2):
             slice_of_state = IREE.tensor_reshape(
                 state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
             )
@@ -278,7 +285,7 @@ def evict_kvcache_space(self):
         sink_size = 4
         window_size = 252
         most_recent_window = self.global_seq_step + (-window_size)
-        for i in range(HEADS * 2):
+        for i in range(NUM_LAYERS * 2):
             update_window_state = IREE.tensor_slice(
                 self.global_state,
                 i,
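This is the StreamingLLM eviction: the first sink_size tokens are kept in place as attention sinks, and the most recent window_size tokens slide down to sit directly after them. A toy numpy-style sketch of the intent (illustrative; the real code does this per kv slot with IREE tensor ops):

    import numpy as np

    sink_size, window_size = 4, 252
    global_seq_step = 300                         # toy current cache length
    cache = np.random.rand(4, 1, 512, 4, 8)       # (slots, batch, seq, heads, dim)
    start = global_seq_step - window_size         # == most_recent_window above
    for i in range(cache.shape[0]):
        # sinks at [0:sink_size] stay put; freshest window slides down beside them
        cache[i, 0, sink_size:sink_size + window_size] = \
            cache[i, 0, start:start + window_size].copy()
    global_seq_step = sink_size + window_size     # cache compacted to 256 tokens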
@@ -339,12 +346,14 @@ def evict_kvcache_space(self):
             [
                 "--iree-rocm-target-chip=" + target_triple,
                 "--iree-rocm-link-bc=true",
-                "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
                 "--iree-vm-bytecode-module-strip-source-map=true",
                 "--iree-opt-strip-assertions=true",
                 "--iree-vm-target-truncate-unsupported-floats",
             ]
         )
+        ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"}
+        if target_triple in ukernel_supported_arch:
+            flags.extend(["--iree-rocm-enable-ukernels=argmax"])
 elif device == "cuda":
     flags.extend(
         [