2222
2323BATCH_SIZE = 1
2424
25- import argparse
26-
27- parser = argparse .ArgumentParser ()
28- parser .add_argument (
29- "--hf_auth_token" , type = str , help = "The Hugging Face auth token, required"
30- )
31- parser .add_argument ("--compile_to" , type = str , help = "torch, linalg, vmfb" )
32- parser .add_argument (
33- "--hf_model_name" ,
34- type = str ,
35- help = "HF model name" ,
36- default = "Trelis/Llama-2-7b-chat-hf-function-calling-v2" ,
37- )
38- parser .add_argument ("--quantization" , type = str , default = "unquantized" )
39- parser .add_argument ("--external_weight_file" , type = str , default = "" )
40- parser .add_argument (
41- "--vmfb_path" , type = str , default = None , help = "Path/name to store compiled vmfb."
42- )
43- parser .add_argument (
44- "--external_weights" ,
45- type = str ,
46- default = None ,
47- help = "saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]" ,
48- )
49- parser .add_argument (
50- "--precision" , type = str , default = "fp16" , help = "dtype of model [f16, f32]"
51- )
52- parser .add_argument (
53- "--device" , type = str , default = "llvm-cpu" , help = "llvm-cpu, cuda, vulkan, rocm"
54- )
55- # TODO: Bring in detection for target triple
56- parser .add_argument (
57- "--iree_target_triple" ,
58- type = str ,
59- default = "host" ,
60- help = "Specify vulkan target triple or rocm/cuda target device." ,
61- )
62- parser .add_argument ("--vulkan_max_allocation" , type = str , default = "4294967296" )
63- parser .add_argument (
64- "--streaming_llm" ,
65- action = "store_true" ,
66- help = "Compile LLM with StreamingLLM optimizations" ,
67- )
68- parser .add_argument (
69- "--decomp_attn" ,
70- action = "store_true" ,
71- help = "Decompose attention ops at fx graph level." ,
72- )
73-
7425
7526def generate_schema (num_layers ):
7627 null = None
@@ -519,51 +470,31 @@ def evict_kvcache_space(self):
519470}
520471
521472
522- class StatelessLlamaPipeline :
473+ class StatelessLlama :
523474 def __init__ (
524475 self ,
525476 hf_model_name : str ,
526- scheduler_id : str ,
527- height : int ,
528- width : int ,
529477 precision : str ,
530- max_length : int ,
531- batch_size : int ,
532- num_inference_steps : int ,
533478 device : str ,
534479 iree_target_triple : str ,
535480 ireec_flags : list = [],
536- attn_spec : str = None ,
537- decomp_attn : bool = False ,
538481 pipeline_dir : str | Path = "./shark_vmfbs" ,
539482 external_weights_dir : str | Path = "./shark_weights" ,
540483 external_weights : str = "safetensors" ,
541- custom_vae : str = None ,
542- vae_decomp_attn : bool = True ,
543484 hf_auth_token : str = None ,
485+ streaming_llm : bool = False ,
544486 ):
545487 self .hf_model_name = hf_model_name
546488 self .iree_dtype = "float32" if precision == "fp32" else "float16"
547489 self .torch_dtype = torch .float32 if precision == "fp32" else torch .float16
548490 self .cpu_scheduling = True
549- self .scheduler_id = scheduler_id
550- self .height = height
551- self .width = width
552491 self .precision = precision
553- self .max_length = max_length
554- self .model_max_length = max_length
555- self .batch_size = batch_size
556- self .num_inference_steps = num_inference_steps
557492 self .device = device
558493 self .iree_target_triple = iree_target_triple
559494 self .ireec_flags = ireec_flags
560- self .attn_spec = attn_spec
561- self .decomp_attn = decomp_attn
562495 self .pipeline_dir = pipeline_dir
563496 self .external_weights_dir = external_weights_dir
564497 self .external_weights = external_weights
565- self .custom_vae = custom_vae
566- self .vae_decomp_attn = vae_decomp_attn
567498
568499 self .first_input = True
569500 self .max_tokens = llm_model_map [self .hf_model_name ]["max_tokens" ]
@@ -582,10 +513,11 @@ def __init__(
582513 )
583514 self .model = None
584515 self .hf_auth_token = hf_auth_token
516+ self .streaming_llm = streaming_llm
585517
586518 # FILE MANAGEMENT AND PIPELINE SETUP
587519
588- def check_prepared (
520+ def prepare_pipeline (
589521 self ,
590522 mlir : str ,
591523 vmfb : str ,
@@ -660,8 +592,8 @@ def export(
660592 weights_only : bool = False ,
661593 ):
662594 safe_name = self .hf_model_name .replace ("-" , "_" ).replace ("/" , "_" )
663- # if self.streaming_llm:
664- safe_name += "_streaming"
595+ if self .streaming_llm :
596+ safe_name += "_streaming"
665597
666598 if not os .path .exists (self .pipeline_dir ):
667599 os .makedirs (self .pipeline_dir )
@@ -698,7 +630,7 @@ def export(
698630 device = self .device ,
699631 target_triple = self .iree_target_triple ,
700632 vulkan_max_allocation = None ,
701- streaming_llm = True ,
633+ streaming_llm = self . streaming_llm ,
702634 vmfb_path = os .path .join (self .pipeline_dir , safe_name + ".vmfb" ),
703635 upload_ir = False ,
704636 mod = None ,
@@ -732,9 +664,12 @@ def format_out(results):
732664
733665 history = []
734666 for iter in range (self .max_tokens ):
735- # if self.streaming_llm:
736- token_slice = max (self .prev_token_len - 1 , 0 )
737- input_tensor = input_tensor [:, token_slice :]
667+ if self .streaming_llm :
668+ token_slice = max (self .prev_token_len - 1 , 0 )
669+ input_tensor = input_tensor [:, token_slice :]
670+ else :
671+ # TODO
672+ pass
738673 # if self.streaming_llm and self.model["get_seq_step"]() > 600:
739674 if self .model ["get_seq_step" ]() > 600 :
740675 print ("Evicting cache space!" )
@@ -743,7 +678,7 @@ def format_out(results):
743678 device_inputs = [
744679 ireert .asdevicearray (self .device , input_tensor )
745680 ]
746- if self .first_input : # or not self.streaming_llm:
681+ if self .first_input or not self .streaming_llm :
747682 st_time = time .time ()
748683 token = self .model ["run_initialize" ](* device_inputs )
749684 total_time = time .time () - st_time
@@ -820,33 +755,17 @@ def format_out(results):
820755 if not args .external_weights_dir and args .external_weights :
821756 args .external_weights_dir = args .pipeline_dir
822757
823- sd_pipe = StatelessLlamaPipeline (
758+ llama = StatelessLlama (
824759 args .hf_model_name ,
825- args .scheduler_id ,
826- args .height ,
827- args .width ,
828760 args .precision ,
829- args .max_length ,
830- args .batch_size ,
831- args .num_inference_steps ,
832761 args .device ,
833762 args .iree_target_triple ,
834763 flags ,
835- args .attn_spec ,
836- args .decomp_attn ,
837764 args .pipeline_dir ,
838765 args .external_weights_dir ,
839766 args .external_weights ,
840- args .vae_decomp_attn ,
841767 args .hf_auth_token ,
768+ True ,
842769 )
843- vmfb , weight = sd_pipe .check_prepared (mlir , vmfb , weight , interactive = False , quantization = "int4" )
844- sd_pipe .load_pipeline (vmfb , weight , args .rt_device , args .compiled_pipeline )
845- sd_pipe .generate_images (
846- args .prompt ,
847- args .negative_prompt ,
848- args .batch_count ,
849- args .guidance_scale ,
850- args .seed ,
851- False ,
852- )
770+ vmfb , weight = llama .prepare_pipeline (mlir , vmfb , weight , interactive = False , quantization = "int4" )
771+ llama .load_pipeline (vmfb , weight , args .rt_device , args .compiled_pipeline )
0 commit comments