1010from typing import Any , Dict
1111
1212import torch
13-
14- from executorch .examples .models .model_base import EagerModelBase
15- from torchtune .models .llama3_2_vision ._convert_weights import llama3_vision_meta_to_tune
16- from torchtune .models .llama3_2_vision ._component_builders import llama3_2_vision_decoder
1713from executorch .examples .models .checkpoint import (
18- get_default_model_resource_dir ,
1914 get_checkpoint_dtype ,
15+ get_default_model_resource_dir ,
2016)
2117
18+ from executorch .examples .models .model_base import EagerModelBase
19+ from torchtune .models .llama3_2_vision ._component_builders import llama3_2_vision_decoder
20+ from torchtune .models .llama3_2_vision ._convert_weights import llama3_vision_meta_to_tune
21+
2222
2323def to_decoder_checkpoint (checkpoint : Dict [str , Any ]) -> Dict [str , Any ]:
2424 """
2525 Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
2626 weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
2727 To load the text decoder on its own, the "decoder" prefix needs to be removed.
2828 """
29- return {"." .join (weight .split ("." )[1 :]): value for weight , value in checkpoint .items () if weight .startswith ("decoder" )}
29+ return {
30+ "." .join (weight .split ("." )[1 :]): value
31+ for weight , value in checkpoint .items ()
32+ if weight .startswith ("decoder" )
33+ }
34+
3035
3136class Llama3_2Decoder (EagerModelBase ):
3237 """
@@ -36,7 +41,9 @@ class Llama3_2Decoder(EagerModelBase):
3641 def __init__ (self , ** kwargs ):
3742 # Set member vars from kwargs.
3843 self .max_seq_len = kwargs .get ("max_seq_len" , 8192 )
39- self .encoder_max_seq_len = kwargs .get ("encoder_max_seq_len" , int (4 * (448 / 14 ) ** 2 + 1 ))
44+ self .encoder_max_seq_len = kwargs .get (
45+ "encoder_max_seq_len" , int (4 * (448 / 14 ) ** 2 + 1 )
46+ )
4047 self .generate_full_logits = kwargs .get ("generate_full_logits" , False )
4148 self .enable_dynamic_shape = kwargs .get ("enable_dynamic_shape" , False )
4249 self .output_prune_map_path = kwargs .get ("output_prune_map_path" , None )
@@ -46,7 +53,6 @@ def __init__(self, **kwargs):
4653 self .verbose = kwargs .get ("verbose" , False )
4754 self .args = kwargs .get ("args" , None )
4855
49-
5056 ckpt_dir = get_default_model_resource_dir (__file__ )
5157 # Single checkpoint file.
5258 checkpoint_path = kwargs .get ("checkpoint" , ckpt_dir / "demo_rand_params.pth" )
@@ -57,7 +63,9 @@ def __init__(self, **kwargs):
5763 # Load checkpoint and params.
5864 device = "cpu"
5965 if checkpoint_dir is not None :
60- raise NotImplementedError ("Sharded checkpoint not yet supported for Llama3_2Decoder." )
66+ raise NotImplementedError (
67+ "Sharded checkpoint not yet supported for Llama3_2Decoder."
68+ )
6169 else :
6270 checkpoint = torch .load (checkpoint_path , map_location = device , mmap = True )
6371 checkpoint = llama3_vision_meta_to_tune (checkpoint )
@@ -107,7 +115,9 @@ def __init__(self, **kwargs):
107115 # Prune the output layer if output_prune_map is provided.
108116 output_prune_map = None
109117 if self .output_prune_map_path is not None :
110- from executorch .examples .models .llama2 .source_transformation .prune_output import prune_output_vocab
118+ from executorch .examples .models .llama2 .source_transformation .prune_output import (
119+ prune_output_vocab ,
120+ )
111121
112122 with open (self .output_prune_map_path , "r" ) as f :
113123 output_prune_map = json .load (f )
@@ -123,9 +133,7 @@ def get_eager_model(self) -> torch.nn.Module:
123133 return self .model_ .to (torch .float16 )
124134
125135 def get_example_inputs (self ):
126- return (
127- torch .ones (1 , 64 , dtype = torch .long ), # positional inputs
128- )
136+ return (torch .ones (1 , 64 , dtype = torch .long ),) # positional inputs
129137
130138 def get_example_kwarg_inputs (self ):
131139 # TODO: add input_pos and mask when after making cache work.
@@ -137,7 +145,7 @@ def get_example_kwarg_inputs(self):
137145 }
138146
139147 def get_dynamic_shapes (self ):
140- dim = torch .export .Dim ("token_dim" , min = 1 ,max = self .max_seq_len )
148+ dim = torch .export .Dim ("token_dim" , min = 1 , max = self .max_seq_len )
141149 dynamic_shapes = {
142150 "tokens" : {0 : 1 , 1 : dim },
143151 # "encoder_input": {0:1, 1:dim_enc, 2:4096},
@@ -146,4 +154,3 @@ def get_dynamic_shapes(self):
146154 # "input_pos" : {0: dim},
147155 }
148156 return dynamic_shapes
149-
0 commit comments