Skip to content

Commit 0e93371

Browse files
Add option to compile_modules pipeline to ingest raw BBs
This patch adds an option to the compile_modules pipeline to ingest basic blocks from a hex file (as in the BHive dataset) rather than compiling IR stored in Parquet. This lets us ingest a lightly processed CSV from the BHive dataset and then annotate and benchmark it with our pipeline, to verify that the numbers match up reasonably well.

Reviewers: virajbshah, ondrasej, orodley
Reviewed By: ondrasej
Pull Request: #339
1 parent e1a90cf commit 0e93371

File tree

3 files changed

+96
-35
lines changed

3 files changed

+96
-35
lines changed

gematria/datasets/pipelines/compile_modules.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
'parquet_folder',
3232
None,
3333
'The path to the folder containing parquet files',
34-
required=True,
3534
)
3635

3736
_OUTPUT_FILE = flags.DEFINE_string(
@@ -71,20 +70,27 @@
7170
' cannot be found.',
7271
)
7372

73+
_INPUT_HEX_BBS_FILE_PATTERN = flags.DEFINE_string(
74+
'input_hex_bbs_file_pattern',
75+
None,
76+
'The path to text files containing new line separated basic blocks.',
77+
)
78+
7479

7580
def main(argv) -> None:
7681
del argv # Unused.
7782

7883
beam_options = pipeline_options.PipelineOptions()
7984

8085
pipeline_constructor = compile_modules_lib.get_bbs(
81-
os.path.join(_PARQUET_FOLDER.value, '*.parquet'),
82-
_OUTPUT_FILE.value,
83-
_REMOVE_MEMORY_ACCESSING_INSTRUCTIONS.value,
84-
ANNOTATOR_MAPPING[_ANNOTATOR_TYPE.value],
85-
_MAX_ANNOTATION_ATTEMPTS.value,
86-
_OUTPUT_VOCAB_FILE.value,
87-
_SKIP_NO_LOOP_REGISTER.value,
86+
input_file_pattern=os.path.join(_PARQUET_FOLDER.value, '*.parquet'),
87+
output_file=_OUTPUT_FILE.value,
88+
remove_memory_accessing_instructions=_REMOVE_MEMORY_ACCESSING_INSTRUCTIONS.value,
89+
annotator_type=ANNOTATOR_MAPPING[_ANNOTATOR_TYPE.value],
90+
max_annotation_attempts=_MAX_ANNOTATION_ATTEMPTS.value,
91+
vocab_output_file=_OUTPUT_VOCAB_FILE.value,
92+
skip_no_loop_register=_SKIP_NO_LOOP_REGISTER.value,
93+
input_hex_bbs_file_pattern=_INPUT_HEX_BBS_FILE_PATTERN.value,
8894
)
8995

9096
with beam.Pipeline(options=beam_options) as pipeline:

gematria/datasets/pipelines/compile_modules_lib.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,14 @@ def process(self, bb_hex: str) -> Iterable[str]:
231231

232232

233233
def get_bbs(
234-
input_file_pattern: str,
234+
input_file_pattern: str | None,
235235
output_file: str,
236236
remove_memory_accessing_instructions: bool,
237237
annotator_type: bhive_to_exegesis.AnnotatorType,
238238
max_annotation_attempts: int,
239239
vocab_output_file: str,
240240
skip_no_loop_register: bool,
241+
input_hex_bbs_file_pattern: str | None,
241242
) -> Callable[[beam.Pipeline], None]:
242243
"""Creates a pipeline to process BBs from IR modules.
243244
@@ -247,7 +248,8 @@ def get_bbs(
247248
248249
Args:
249250
input_file_pattern: A grep-like pattern to use to search for the Parquet
250-
files to process.
251+
files to process. This cannot be used at the same time as
252+
input_hex_bbs_file_pattern.
251253
output_file: The output file pattern to use when writing the basic blocks
252254
to disk.
253255
remove_memory_accessing_instructions: Whether or not to remove memory
@@ -259,31 +261,45 @@ def get_bbs(
259261
vocab_output_file: The output pattern for the vocabulary file.
260262
skip_no_loop_register: Whether or not to omit basic blocks for which a free
261263
register to use as a loop counter cannot be found.
264+
input_hex_bbs_file_pattern: A grep-like file pattern to use to search for
265+
text files that contain basic blocks in hex format. This cannot be used
266+
at the same time as input_file_pattern.
262267
263268
Returns:
264269
A function that accepts a beam pipeline and adds on all the steps needed
265270
to process the input IR modules.
266271
"""
267272

268-
def pipeline(root: beam.Pipeline) -> None:
269-
parquet_data = root | 'Read' >> beam.io.ReadFromParquet(
270-
input_file_pattern, columns=['content']
271-
)
272-
module_data = parquet_data | 'Load' >> beam.Map(
273-
lambda parquet_row: parquet_row['content']
274-
)
275-
module_data_shuffled = module_data | 'Shuffle' >> beam.Reshuffle()
276-
optimized_modules = module_data_shuffled | 'Optimize' >> beam.ParDo(
277-
OptimizeModules(
278-
['default<O0>', 'default<O1>', 'default<O2>', 'default<O3>']
279-
)
280-
)
281-
lowered_modules = optimized_modules | 'Lower' >> beam.ParDo(
282-
LowerModulesAsm(['-O0', '-O1', '-O2', '-O3'])
283-
)
284-
bb_hex_values = lowered_modules | 'Get BBs' >> beam.ParDo(
285-
GetBBsFromModule()
273+
if (input_file_pattern is None) == (input_hex_bbs_file_pattern is None):
274+
raise ValueError(
275+
'Exactly one of input_file_pattern and input_hex_bbs_file_pattern must'
276+
' be set.'
286277
)
278+
279+
def pipeline(root: beam.Pipeline) -> None:
280+
if input_hex_bbs_file_pattern is not None:
281+
bb_hex_values = root | 'Read' >> beam.io.ReadFromText(
282+
input_hex_bbs_file_pattern
283+
)
284+
else:
285+
parquet_data = root | 'Read' >> beam.io.ReadFromParquet(
286+
input_file_pattern, columns=['content']
287+
)
288+
module_data = parquet_data | 'Load' >> beam.Map(
289+
lambda parquet_row: parquet_row['content']
290+
)
291+
module_data_shuffled = module_data | 'Shuffle' >> beam.Reshuffle()
292+
optimized_modules = module_data_shuffled | 'Optimize' >> beam.ParDo(
293+
OptimizeModules(
294+
['default<O0>', 'default<O1>', 'default<O2>', 'default<O3>']
295+
)
296+
)
297+
lowered_modules = optimized_modules | 'Lower' >> beam.ParDo(
298+
LowerModulesAsm(['-O0', '-O1', '-O2', '-O3'])
299+
)
300+
bb_hex_values = lowered_modules | 'Get BBs' >> beam.ParDo(
301+
GetBBsFromModule()
302+
)
287303
bb_hex_values_deduplicated = (
288304
bb_hex_values | 'Deduplicate' >> DeduplicateValues()
289305
)

gematria/datasets/pipelines/compile_modules_lib_test.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,52 @@ def test_get_bbs(self, annotator_type):
191191
)
192192

193193
pipeline_constructor = compile_modules_lib.get_bbs(
194-
test_parquet_file.full_path,
195-
output_file_pattern,
196-
False,
197-
annotator_type,
198-
50,
199-
vocab_output_file_pattern,
200-
False,
194+
input_file_pattern=test_parquet_file.full_path,
195+
output_file=output_file_pattern,
196+
remove_memory_accessing_instructions=False,
197+
annotator_type=annotator_type,
198+
max_annotation_attempts=50,
199+
vocab_output_file=vocab_output_file_pattern,
200+
skip_no_loop_register=False,
201+
input_hex_bbs_file_pattern=None,
202+
)
203+
204+
with test_pipeline.TestPipeline() as pipeline_under_test:
205+
pipeline_constructor(pipeline_under_test)
206+
207+
block_hex_values = []
208+
for annotated_block in tfrecord.read_protos(
209+
[output_file_pattern + '-00000-of-00001'],
210+
execution_annotation_pb2.BlockWithExecutionAnnotations,
211+
):
212+
block_hex_values.append(annotated_block.block_hex)
213+
214+
self.assertLen(block_hex_values, 2)
215+
self.assertContainsSubset(['B801000000', 'B802000000'], block_hex_values)
216+
217+
with open(
218+
vocab_output_file_pattern + '-00000-of-00001'
219+
) as vocab_file_handle:
220+
vocab_tokens = [token.strip() for token in vocab_file_handle.readlines()]
221+
222+
self.assertCountEqual(['_D_', '_IMMEDIATE_', 'MOV', 'EAX'], vocab_tokens)
223+
224+
def test_get_bbs_hex_file(self):
225+
test_bb_file = self.create_tempfile()
226+
output_file_dir = self.create_tempdir()
227+
output_file_pattern = os.path.join(output_file_dir, 'bbs')
228+
vocab_output_file_pattern = os.path.join(output_file_dir, 'bbvocab')
229+
test_bb_file.write_text('B801000000\nB802000000\n')
230+
231+
pipeline_constructor = compile_modules_lib.get_bbs(
232+
input_file_pattern=None,
233+
output_file=output_file_pattern,
234+
remove_memory_accessing_instructions=False,
235+
annotator_type=bhive_to_exegesis.AnnotatorType.exegesis,
236+
max_annotation_attempts=50,
237+
vocab_output_file=vocab_output_file_pattern,
238+
skip_no_loop_register=False,
239+
input_hex_bbs_file_pattern=test_bb_file.full_path,
201240
)
202241

203242
with test_pipeline.TestPipeline() as pipeline_under_test:

0 commit comments

Comments
 (0)