1- # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2-
31import torch
42
5- from executorch .exir import to_edge_transform_and_lower
63from executorch .devtools import BundledProgram
74
85from executorch .devtools .bundled_program .config import MethodTestCase , MethodTestSuite
96from executorch .devtools .bundled_program .serialize import (
107 serialize_from_bundled_program_to_flatbuffer ,
118)
9+
10+ from executorch .exir import to_edge_transform_and_lower
1211from torch .export import export , export_for_training
1312
1413# Step 1: ExecuTorch Program Export
@@ -17,8 +16,8 @@ class SampleModel(torch.nn.Module):
1716
1817 def __init__ (self ) -> None :
1918 super ().__init__ ()
20- self .register_buffer ('a' , 3 * torch .ones (2 , 2 , dtype = torch .int32 ))
21- self .register_buffer ('b' , 2 * torch .ones (2 , 2 , dtype = torch .int32 ))
19+ self .register_buffer ("a" , 3 * torch .ones (2 , 2 , dtype = torch .int32 ))
20+ self .register_buffer ("b" , 2 * torch .ones (2 , 2 , dtype = torch .int32 ))
2221
2322 def forward (self , x : torch .Tensor , q : torch .Tensor ) -> torch .Tensor :
2423 z = x .clone ()
@@ -76,7 +75,7 @@ def main() -> None:
7675 test_cases = [
7776 MethodTestCase (
7877 inputs = input ,
79- expected_outputs = (getattr (model , method_name )(* input ), ),
78+ expected_outputs = (getattr (model , method_name )(* input ),),
8079 )
8180 for input in inputs
8281 ],
0 commit comments