2 files changed, +10 −0 lines changed

@@ -18,6 +18,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from torchao.utils import TorchAOBaseTensor
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
                 strict=False,
                 assign=True,
             )  # self.model_ = Transformer(gptconf)
+            for param in self.model_.parameters():
+                if isinstance(param, TorchAOBaseTensor):
+                    param.requires_grad = False
         else:
             print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
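For context, a minimal sketch (assumed torchao API and toy module, not part of this diff) of where such tensor-subclass parameters come from: torchao's quantize_ swaps Linear weights for subclasses of TorchAOBaseTensor, which is what the loop added above detects and freezes after the checkpoint is loaded with assign=True.

import torch
from torchao.quantization import quantize_, int8_weight_only  # illustrative config choice
from torchao.utils import TorchAOBaseTensor

# Toy stand-in for the Llama Transformer; any module with Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
quantize_(model, int8_weight_only())  # weights become TorchAOBaseTensor subclasses

# Mirror of the loop in the diff: ensure quantized weights are not trainable,
# since autograd over subclass parameters can break later tracing/export.
for param in model.parameters():
    if isinstance(param, TorchAOBaseTensor):
        param.requires_grad = False

The second file touched by the diff feeds such a model through export, unwrapping the subclasses first: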
@@ -41,6 +41,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
+from torchao.utils import unwrap_tensor_subclass
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -199,6 +200,11 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config
 
     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
+        if module is not None:
+            unwrap_tensor_subclass(module)
+        else:
+            unwrap_tensor_subclass(self.model)
+
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
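A hedged sketch of the export step the change above enables: unwrap_tensor_subclass replaces subclass parameters with plain tensors (via parametrization) so that torch.export can trace the module. The toy module, the int8_weight_only config, and the example input are assumptions for illustration; only unwrap_tensor_subclass and export_for_training appear in the diff.

import torch
from torch.export import export_for_training
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
quantize_(model, int8_weight_only())

# Without this call, export can fail on the tensor-subclass parameters.
unwrap_tensor_subclass(model)
ep = export_for_training(model, (torch.randn(1, 16),))
print(ep.graph_module.code)

In the diff, the same call is applied to whichever module _export is about to trace: the explicit module argument when given, otherwise self.model.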