@@ -40,16 +40,14 @@ class Llama3_2Decoder(EagerModelBase):
 
     def __init__(self, **kwargs):
         # Set member vars from kwargs.
-        self.max_seq_len = kwargs.get("max_seq_len", 8192)
+        self.max_seq_len = kwargs.get("max_seq_len", 8192)  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
         self.encoder_max_seq_len = kwargs.get(
             "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
-        )
+        )  # Same as above.
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        # TODO: enable kv cache with TransformerDecoder's setup_cache().
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache = kwargs.get("use_sdpa_with_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
 
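For reference, the encoder_max_seq_len default above works out to 4097 tokens. A minimal breakdown, assuming the constants correspond to a 448x448 tile, 14x14 patches, and 4 tiles plus one extra token (an interpretation not stated in the diff; the variable names below are illustrative only):

# Hypothetical breakdown of int(4 * (448 / 14) ** 2 + 1); the names are assumptions.
tile_size = 448                                         # assumed tile resolution
patch_size = 14                                         # assumed patch size
num_tiles = 4                                           # assumed tile count
patches_per_tile = (tile_size // patch_size) ** 2       # 32 ** 2 = 1024
encoder_max_seq_len = num_tiles * patches_per_tile + 1  # 4 * 1024 + 1 = 4097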
@@ -60,6 +58,14 @@ def __init__(self, **kwargs):
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
 
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )
+        self.input_pos = torch.arange(self.max_seq_len)
+
         # Load checkpoint and params.
         device = "cpu"
         if checkpoint_dir is not None:
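The two buffers added above are full-length lookup tables that later code can slice per forward call instead of rebuilding masks. A minimal, standalone sketch of what a 64-token slice of them looks like; the shapes the decoder ultimately expects are an assumption here, not something this diff confirms:

# Standalone illustration, not part of the PR.
import torch

max_seq_len = 8192
causal_mask = torch.tril(torch.ones(size=(max_seq_len, max_seq_len), dtype=torch.bool))
input_pos = torch.arange(max_seq_len)

seq_len = 64
mask = causal_mask[None, :seq_len, :]   # (1, 64, 8192): row i permits attention to positions 0..i
pos = input_pos[None, :seq_len]         # (1, 64): positions 0..63
print(mask.shape, pos.shape)            # torch.Size([1, 64, 8192]) torch.Size([1, 64])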
@@ -126,22 +132,30 @@ def __init__(self, **kwargs):
 
         self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
+        # if self.use_kv_cache:
+        #     print("Setting up KV cache on the model...")
+        #     self.model_.setup_caches(
+        #         batch_size=1,
+        #         dtype=self.dtype,
+        #     )
+
     def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             return self.model_.to(self.dtype)
         else:
             return self.model_.to(torch.float16)
 
     def get_example_inputs(self):
-        return (torch.ones(1, 64, dtype=torch.long),)  # positional inputs
+        return (torch.ones(1, 64, dtype=torch.long),)
 
     def get_example_kwarg_inputs(self):
         # TODO: add input_pos and mask after making the cache work.
         return {
-            # "mask": None,
+            # "mask": self.causal_mask[None, 64, None, :],
             # "encoder_input": None,
             # "encoder_mask": None,
             # "input_pos": torch.ones(64, dtype=torch.long),
+            # "input_pos": self.input_pos[None, 64],
         }
 
     def get_dynamic_shapes(self):
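The commented-out kwargs above index the buffers at a single position (64), which looks like a one-token decode step against a static cache. A minimal standalone sketch of that indexing; the extra None axes mirror the commented line and are an assumption about the mask layout the decoder will want:

# Standalone illustration of the single-position slice used in the commented-out kwargs.
import torch

max_seq_len = 8192
causal_mask = torch.tril(torch.ones(size=(max_seq_len, max_seq_len), dtype=torch.bool))
input_pos = torch.arange(max_seq_len)

t = 64                                # current decode position, as in the diff
mask = causal_mask[None, t, None, :]  # (1, 1, 8192): row t of the causal mask
pos = input_pos[None, t]              # tensor([64]), shape (1,)
print(mask.shape, pos.shape)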