@@ -60,8 +60,11 @@ def _setup_base_policy(self):
60
60
saver .save (self ._tf_base_temp_dir )
61
61
self ._tf_base_policy_path = os .path .join (self ._tf_base_temp_dir , "policy" )
62
62
63
- def _copy_corpus (self , corpus_path : str ,
64
- copy_corpus_locally_path : str | None ) -> None :
63
+ # TODO(issues/471): aux_file_replacement_flags should be refactored out of
64
+ # regalloc_trace_worker as it will need to be used in other places
65
+ # eventually.
66
+ def _copy_corpus (self , corpus_path : str , copy_corpus_locally_path : str | None ,
67
+ aux_file_replacement_flags : dict [str , str ]) -> None :
65
68
"""Makes a local copy of the corpus if requested.
66
69
67
70
This function makes a local copy of the corpus by copying the remote
@@ -70,6 +73,8 @@ def _copy_corpus(self, corpus_path: str,
70
73
Args:
71
74
corpus_path: The path to the remote corpus.
72
75
copy_corpus_locally: The local path to copy the corpus to.
76
+ aux_file_replacement_flags: Additional files to copy over that are
77
+ passed in through flags, like profiles.
73
78
"""
74
79
# We use the tensorflow APIs below rather than the standard Python file
75
80
# APIs for compatibility with more filesystems.
@@ -97,18 +102,30 @@ def _copy_corpus(self, corpus_path: str,
97
102
copy_thread_pool .submit (_make_dirs_and_copy , current_path ,
98
103
new_path ))
99
104
105
+ if aux_file_replacement_flags is not None :
106
+ for flag_name in aux_file_replacement_flags :
107
+ aux_replacement_file = aux_file_replacement_flags [flag_name ]
108
+ new_path = os .path .join (copy_corpus_locally_path ,
109
+ os .path .basename (aux_replacement_file ))
110
+ copy_futures .append (
111
+ copy_thread_pool .submit (_make_dirs_and_copy , aux_replacement_file ,
112
+ new_path ))
113
+
100
114
for copy_future in copy_futures :
101
115
if copy_future .exception () is not None :
102
116
raise copy_future .exception ()
103
117
104
- def __init__ (self ,
105
- * ,
106
- gin_config : str ,
107
- clang_path : str ,
108
- basic_block_trace_model_path : str ,
109
- thread_count : int ,
110
- corpus_path : str ,
111
- copy_corpus_locally_path : str | None = None ):
118
+ def __init__ (
119
+ self ,
120
+ * ,
121
+ gin_config : str ,
122
+ clang_path : str ,
123
+ basic_block_trace_model_path : str ,
124
+ thread_count : int ,
125
+ corpus_path : str ,
126
+ copy_corpus_locally_path : str | None = None ,
127
+ aux_file_replacement_flags : dict [str , str ] | None = None ,
128
+ ):
112
129
"""Initializes the RegallocTraceWorker class.
113
130
114
131
Args:
@@ -124,17 +141,39 @@ def __init__(self,
124
141
copy_corpus_locally_path: If set, specifies the path that the corpus
125
142
should be copied to before utilizing the modules for evaluation.
126
143
Setting this to None signifies that no copying is desired.
144
+ aux_file_replacement_flags: A dictionary mapping sentinel values intended
145
+ to be set using the corpus replace_flags feature to actual file paths
146
+ local to the worker. This is intended to be used in distributed
147
+ training setups where training corpora and auxiliary files need to be
148
+ copied locally before being compiled.
127
149
"""
128
150
self ._clang_path = clang_path
129
151
self ._basic_block_trace_model_path = basic_block_trace_model_path
130
152
self ._thread_count = thread_count
153
+
131
154
self ._has_local_corpus = False
132
155
self ._corpus_path = corpus_path
133
156
if copy_corpus_locally_path is not None :
134
- self ._copy_corpus (corpus_path , copy_corpus_locally_path )
157
+ self ._copy_corpus (corpus_path , copy_corpus_locally_path ,
158
+ aux_file_replacement_flags )
135
159
self ._corpus_path = copy_corpus_locally_path
136
160
self ._has_local_corpus = True
137
161
162
+ if (copy_corpus_locally_path is None and
163
+ aux_file_replacement_flags is not None ):
164
+ raise ValueError (
165
+ "additional_replacement_flags is incompatible with fully local "
166
+ "corpus setups. Please directly replace the flag with the correct "
167
+ "value." )
168
+ self ._aux_file_replacement_flags = aux_file_replacement_flags
169
+ self ._aux_file_replacement_context = {}
170
+ if aux_file_replacement_flags is not None :
171
+ for flag_name in self ._aux_file_replacement_flags :
172
+ self ._aux_file_replacement_context [flag_name ] = os .path .join (
173
+ self ._corpus_path ,
174
+ os .path .basename (self ._aux_file_replacement_flags [flag_name ]),
175
+ )
176
+
138
177
gin .parse_config (gin_config )
139
178
self ._setup_base_policy ()
140
179
@@ -156,7 +195,7 @@ def _compile_module(self, module_to_compile: corpus.ModuleSpec,
156
195
# using ThinLTO, we will just never end up replacing anything.
157
196
os .path .join (self ._corpus_path , module_to_compile .name ) + ".thinlto.bc" )
158
197
command_vector .extend ([
159
- option .format (context = context )
198
+ option .format (context = context , ** self . _aux_file_replacement_context )
160
199
for option in module_to_compile .command_line
161
200
])
162
201
0 commit comments