@@ -1534,26 +1534,49 @@ def forward(self, x):
15341534 self .assertEqual (len (program .constant_buffer [1 ].storage ), 8 )
15351535
15361536 def test_emit_lifted_tensor_constant (self ) -> None :
1537- class LiftedConstants (nn .Module ):
1537+ class LiftedTensorConstants (nn .Module ):
15381538 def __init__ (self ):
15391539 super ().__init__ ()
15401540
15411541 def forward (self , x ):
15421542 x = x * torch .tensor ([[4 , 3 ], [1 , 2 ], [5 , 6 ]], dtype = torch .float )
15431543 return x
15441544
1545- model = LiftedConstants ()
1545+ model = LiftedTensorConstants ()
1546+ # Specify that we want to move non-lifted constants to external file
1547+ et_cfg = ExecutorchBackendConfig (external_constants = True )
1548+ program = to_edge (
1549+ export (model , (torch .ones (3 , 2 ),), strict = True )
1550+ ).to_executorch (et_cfg )
1551+ program = program ._emitter_output .program
1552+ exec_plan = program .execution_plan [0 ]
1553+ # There should only be 1 input to this model.
1554+ self .assertEqual (len (exec_plan .inputs ), 1 )
1555+ self .assertEqual (len (program .constant_buffer ), 2 )
1556+ self .assertEqual (len (program .constant_buffer [1 ].storage ), 24 )
15461557
1558+ def test_emit_lifted_constant (self ) -> None :
1559+ class LiftedConstants (nn .Module ):
1560+ def __init__ (self ):
1561+ super ().__init__ ()
1562+
1563+ def forward (self , x ):
1564+ x = x + 1
1565+ return x
1566+
1567+ model = LiftedConstants ()
1568+ # Specify that we want to move non-lifted constants to external file
1569+ et_cfg = ExecutorchBackendConfig (external_constants = True )
15471570 program = to_edge (
15481571 export (model , (torch .ones (3 , 2 ),), strict = True )
1549- ).to_executorch ()
1572+ ).to_executorch (et_cfg )
15501573
15511574 program = program ._emitter_output .program
15521575 exec_plan = program .execution_plan [0 ]
15531576 # There should only be 1 input to this model.
15541577 self .assertEqual (len (exec_plan .inputs ), 1 )
15551578 self .assertEqual (len (program .constant_buffer ), 2 )
1556- self .assertEqual (len (program .constant_buffer [1 ].storage ), 24 )
1579+ self .assertEqual (len (program .constant_buffer [1 ].storage ), 8 )
15571580
15581581 def test_mutable_buffers (self ) -> None :
15591582 def count_copies (gm : torch .fx .GraphModule ) -> int :
0 commit comments