27
27
import concurrent .futures
28
28
import tempfile
29
29
import shutil
30
+ from typing import Any
30
31
31
32
import gin
33
+ import tensorflow as tf
32
34
33
35
from compiler_opt .rl import corpus
34
36
from compiler_opt .distributed import worker
35
37
from compiler_opt .rl import policy_saver
36
38
from compiler_opt .es import policy_utils
37
39
38
40
41
+ def _make_dirs_and_copy (old_file_path : str , new_file_path : str ):
42
+ tf .io .gfile .makedirs (os .path .dirname (new_file_path ))
43
+ tf .io .gfile .copy (old_file_path , new_file_path )
44
+
45
+
39
46
@gin .configurable
40
47
class RegallocTraceWorker (worker .Worker ):
41
48
"""A worker that produces rewards for a given regalloc policy.
@@ -53,9 +60,55 @@ def _setup_base_policy(self):
53
60
saver .save (self ._tf_base_temp_dir )
54
61
self ._tf_base_policy_path = os .path .join (self ._tf_base_temp_dir , "policy" )
55
62
56
- def __init__ (self , * , gin_config : str , clang_path : str ,
57
- basic_block_trace_model_path : str , thread_count : int ,
58
- corpus_path : str ):
63
+ def _copy_corpus (self , corpus_path : str ,
64
+ copy_corpus_locally_path : str | None ) -> None :
65
+ """Makes a local copy of the corpus if requested.
66
+
67
+ This function makes a local copy of the corpus by copying the remote
68
+ corpus to a user-specified directory.
69
+
70
+ Args:
71
+ corpus_path: The path to the remote corpus.
72
+ copy_corpus_locally: The local path to copy the corpus to.
73
+ """
74
+ # We use the tensorflow APIs below rather than the standard Python file
75
+ # APIs for compatibility with more filesystems.
76
+
77
+ if tf .io .gfile .exists (copy_corpus_locally_path ):
78
+ return
79
+
80
+ with tf .io .gfile .GFile (
81
+ os .path .join (corpus_path , "corpus_description.json" ),
82
+ "r" ) as corpus_description_file :
83
+ corpus_description : dict [str , Any ] = json .load (corpus_description_file )
84
+
85
+ file_extensions_to_copy = [".bc" , ".cmd" ]
86
+ if corpus_description ["has_thinlto" ]:
87
+ file_extensions_to_copy .append (".thinlto.bc" )
88
+
89
+ copy_futures = []
90
+ with concurrent .futures .ThreadPoolExecutor (self ._thread_count *
91
+ 5 ) as copy_thread_pool :
92
+ for module in corpus_description ["modules" ]:
93
+ for extension in file_extensions_to_copy :
94
+ current_path = os .path .join (corpus_path , module + extension )
95
+ new_path = os .path .join (copy_corpus_locally_path , module + extension )
96
+ copy_futures .append (
97
+ copy_thread_pool .submit (_make_dirs_and_copy , current_path ,
98
+ new_path ))
99
+
100
+ for copy_future in copy_futures :
101
+ if copy_future .exception () is not None :
102
+ raise copy_future .exception ()
103
+
104
+ def __init__ (self ,
105
+ * ,
106
+ gin_config : str ,
107
+ clang_path : str ,
108
+ basic_block_trace_model_path : str ,
109
+ thread_count : int ,
110
+ corpus_path : str ,
111
+ copy_corpus_locally_path : str | None = None ):
59
112
"""Initializes the RegallocTraceWorker class.
60
113
61
114
Args:
@@ -68,11 +121,19 @@ def __init__(self, *, gin_config: str, clang_path: str,
68
121
thread_count: The number of threads to use for concurrent compilation
69
122
and modelling.
70
123
corpus_path: The path to the corpus that modules will be compiled from.
124
+ copy_corpus_locally_path: If set, specifies the path that the corpus
125
+ should be copied to before utilizing the modules for evaluation.
126
+ Setting this to None signifies that no copying is desired.
71
127
"""
72
128
self ._clang_path = clang_path
73
129
self ._basic_block_trace_model_path = basic_block_trace_model_path
74
130
self ._thread_count = thread_count
131
+ self ._has_local_corpus = False
75
132
self ._corpus_path = corpus_path
133
+ if copy_corpus_locally_path is not None :
134
+ self ._copy_corpus (corpus_path , copy_corpus_locally_path )
135
+ self ._corpus_path = copy_corpus_locally_path
136
+ self ._has_local_corpus = True
76
137
77
138
gin .parse_config (gin_config )
78
139
self ._setup_base_policy ()
@@ -83,6 +144,8 @@ def __init__(self, *, gin_config: str, clang_path: str,
83
144
# have tempdirs wiped periodically.
84
145
def __del__ (self ):
85
146
shutil .rmtree (self ._tf_base_temp_dir )
147
+ if self ._has_local_corpus :
148
+ shutil .rmtree (self ._corpus_path )
86
149
87
150
def _compile_module (self , module_to_compile : corpus .ModuleSpec ,
88
151
output_directory : str , tflite_policy_path : str | None ):
0 commit comments