@@ -46,7 +46,12 @@ def vela_bin_pack_io(prefix, data):
4646# Output via Vela to binary stream for ArmBackendEthosU
4747# WARNING: Do not change this without changing VelaBinStream.cpp as that
4848# function consumes this format and the two need to align.
49- def vela_compile (tosa_flatbuffer : bytes , args : List [str ], verbose : bool = False ):
49+ def vela_compile (
50+ tosa_flatbuffer : bytes ,
51+ args : List [str ],
52+ verbose : bool = False ,
53+ intermediate_path : str | None = None ,
54+ ):
5055 """
5156 Compile a TOSA graph to a binary stream for ArmBackendEthosU using Vela.
5257 """
@@ -55,14 +60,14 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
5560 "ethos-u-vela pip package couldn't be imported. Make sure it's installed!"
5661 )
5762
58- with tempfile . TemporaryDirectory () as tmpdir :
63+ def run ( dir : str ) -> bytes :
5964 tosaname = "out.tosa"
60- tosa_path = os .path .join (tmpdir , tosaname )
65+ tosa_path = os .path .join (dir , tosaname )
6166 with open (tosa_path , "wb" ) as f :
6267 f .write (tosa_flatbuffer )
6368
6469 # invoke vela
65- output_dir = os .path .join (tmpdir , "output" )
70+ output_dir = os .path .join (dir , "output" )
6671 args .append (f"--output-dir={ output_dir } " )
6772 args .append (tosa_path )
6873 if verbose :
@@ -72,9 +77,9 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
7277 if any ("ethos-u85" in arg for arg in args ) or any (
7378 "debug-force-regor" in arg for arg in args
7479 ):
75- np_path = os .path .join (tmpdir , "output" , "out_vela.npz" )
80+ np_path = os .path .join (dir , "output" , "out_vela.npz" )
7681 else :
77- np_path = os .path .join (tmpdir , "output" , "out_sg0_vela.npz" )
82+ np_path = os .path .join (dir , "output" , "out_sg0_vela.npz" )
7883
7984 blocks = b""
8085 with np .load (np_path , allow_pickle = False ) as data :
@@ -122,3 +127,9 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
122127 blocks = blocks + block
123128
124129 return blocks
130+
131+ if intermediate_path is not None :
132+ return run (intermediate_path )
133+ else :
134+ with tempfile .TemporaryDirectory () as tmpdir :
135+ return run (tmpdir )
0 commit comments