Skip to content

Commit c97256d

Browse files
authored
[MIR2Vec] Add MIR support to triplet generator script (llvm#164332)
Add support for MIR (Machine IR) triplet generation to the triplet gen script.
1 parent 910cf51 commit c97256d

File tree

1 file changed

+131
-34
lines changed

1 file changed

+131
-34
lines changed

llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py

Lines changed: 131 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4-
"""IR2Vec Triplet Generator
4+
"""IR2Vec/MIR2Vec Triplet Generator
55
6-
Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
7-
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
8-
files: entity2id.txt, relation2id.txt, and train2id.txt.
6+
Generates IR2Vec or MIR2Vec triplets by applying random optimization levels to
7+
LLVM IR files (or processing MIR files) and extracting triplets using llvm-ir2vec.
8+
Automatically generates preprocessed files (entity2id.txt, relation2id.txt, and
9+
train2id.txt) necessary for training IR2Vec or MIR2Vec vocabularies.
910
1011
Usage:
11-
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
12+
For LLVM IR:
13+
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
14+
15+
For Machine IR:
16+
python generateTriplets.py --mode=mir <llvm_build_dir> <mir_file_list> <output_dir>
1217
"""
1318

1419
import argparse
@@ -41,19 +46,21 @@ def __init__(self, triplets: Set[str], max_relation: int):
4146

4247

4348
class IR2VecTripletGenerator:
44-
"""Main class for generating IR2Vec triplets"""
49+
"""Main class for generating IR2Vec or MIR2Vec triplets"""
4550

4651
def __init__(
4752
self,
4853
llvm_build_dir: Path,
4954
num_optimizations: int,
5055
output_dir: Path,
5156
max_workers: int = DEFAULT_MAX_WORKERS,
57+
mode: str = "llvm",
5258
):
5359
self.llvm_build_dir = llvm_build_dir
5460
self.num_optimizations = num_optimizations
5561
self.output_dir = output_dir
5662
self.max_workers = max_workers
63+
self.mode = mode # "llvm" or "mir"
5764

5865
# Tool paths
5966
self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
@@ -85,7 +92,11 @@ def _validate_setup(self):
8592
f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
8693
)
8794

88-
if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
95+
if self.mode not in ["llvm", "mir"]:
96+
raise ValueError(f"Mode must be 'llvm' or 'mir', got: {self.mode}")
97+
98+
# For LLVM IR mode, validate optimization count
99+
if self.mode == "llvm" and not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
89100
raise ValueError(
90101
f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
91102
)
@@ -95,19 +106,28 @@ def _select_optimization_levels(self) -> List[str]:
95106
return random.sample(OPT_LEVELS, self.num_optimizations)
96107

97108
def _process_single_file(self, input_file: Path) -> TripletResult:
    """Collect the triplets produced for one LLVM IR or MIR input file.

    In "mir" mode the file goes through ``_run_mir_pipeline`` exactly once.
    Otherwise ``_run_pipeline`` is invoked once for every level returned by
    ``_select_optimization_levels``. Runs that yield no triplets are ignored.

    Returns:
        A ``TripletResult`` built from the union of all triplets and the
        largest relation value reported by any run (never below 1).
    """
    collected = set()
    highest_relation = 1

    def absorb(found, relation, suffix):
        # Fold one pipeline run into the running totals; empty runs are
        # skipped so they neither log nor affect the relation maximum.
        nonlocal highest_relation
        if not found:
            return
        collected.update(found)
        highest_relation = max(highest_relation, relation)
        logger.debug(f"Generated {len(found)} triplets for {input_file}{suffix}")

    if self.mode == "mir":
        # MIR files are processed directly, without any optimization step.
        absorb(*self._run_mir_pipeline(input_file), "")
        return TripletResult(collected, highest_relation)

    # LLVM IR files are run once per selected optimization level.
    for level in self._select_optimization_levels():
        absorb(*self._run_pipeline(input_file, level), f" with {level}")
    return TripletResult(collected, highest_relation)
113133

@@ -124,7 +144,7 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int
124144

125145
# Run llvm-ir2vec with opt's output as input
126146
ir2vec_proc = subprocess.Popen(
127-
[self.ir2vec_binary, "triplets", "-", "-o", "-"],
147+
[self.ir2vec_binary, "triplets", "--mode=llvm", "-", "-o", "-"],
128148
stdin=opt_proc.stdout,
129149
stdout=subprocess.PIPE,
130150
stderr=subprocess.PIPE,
@@ -143,6 +163,32 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int
143163
except (subprocess.SubprocessError, OSError):
144164
return set(), 1
145165

166+
def _run_mir_pipeline(self, input_file: Path) -> Tuple[Set[str], int]:
    """Run ``llvm-ir2vec triplets --mode=mir`` on a single MIR file.

    The file path is handed straight to the tool (no ``opt`` step) and the
    triplets are read from its stdout.

    Args:
        input_file: Path of the MIR file to process.

    Returns:
        The result of ``_parse_triplet_output`` on the tool's stdout, or
        ``(set(), 1)`` when the tool cannot be started or exits non-zero.
        Failures are logged rather than raised.
    """
    command = [
        self.ir2vec_binary,
        "triplets",
        "--mode=mir",
        str(input_file),
        "-o",
        "-",
    ]

    # Only the subprocess call itself is guarded; parsing happens outside.
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=False,
        )
    except (subprocess.SubprocessError, OSError) as err:
        # The tool could not even be launched: report it instead of
        # silently returning an empty result.
        logger.warning(f"Could not run llvm-ir2vec on {input_file}: {err}")
        return set(), 1

    if result.returncode != 0:
        # Keep best-effort semantics, but leave a trace of why the file
        # contributed no triplets.
        logger.debug(
            f"llvm-ir2vec exited with {result.returncode} for {input_file}: "
            f"{result.stderr.strip()}"
        )
        return set(), 1

    return self._parse_triplet_output(result.stdout)
191+
146192
def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
147193
"""Parse triplet output and extract max relation"""
148194
if not output.strip():
@@ -160,12 +206,21 @@ def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
160206
return set(lines), max_relation
161207

162208
def generate_triplets(self, file_list: Path) -> None:
163-
"""Main method to generate triplets from a list of LLVM IR files"""
209+
"""Main method to generate triplets from a list of LLVM IR or MIR files"""
210+
# Store file_list_path for later use in entity generation
211+
self.file_list_path = file_list
212+
164213
input_files = self._read_file_list(file_list)
165-
logger.info(
166-
f"Processing {len(input_files)} files with {self.num_optimizations} "
167-
f"optimization levels using {self.max_workers} workers"
168-
)
214+
215+
if self.mode == "mir":
216+
logger.info(
217+
f"Processing {len(input_files)} MIR files using {self.max_workers} workers"
218+
)
219+
else:
220+
logger.info(
221+
f"Processing {len(input_files)} files with {self.num_optimizations} "
222+
f"optimization levels using {self.max_workers} workers"
223+
)
169224

170225
all_triplets = set()
171226
global_max_relation = 1
@@ -222,28 +277,60 @@ def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> N
222277

223278
def _generate_entity2id(self, output_file: Path) -> None:
224279
"""Generate entity2id.txt using llvm-ir2vec"""
225-
subprocess.run(
226-
[str(self.ir2vec_binary), "entities", "-o", str(output_file)],
227-
check=True,
228-
capture_output=True,
229-
)
280+
if self.mode == "mir":
281+
# For MIR mode, we need to provide a sample MIR file to determine target
282+
# Use the first file from the processed list
283+
input_files = self._read_file_list(self.file_list_path)
284+
if not input_files:
285+
raise ValueError("No input files available for entity generation")
286+
287+
subprocess.run(
288+
[
289+
str(self.ir2vec_binary),
290+
"entities",
291+
"--mode=mir",
292+
str(input_files[0]),
293+
"-o",
294+
str(output_file),
295+
],
296+
check=True,
297+
capture_output=True,
298+
)
299+
else:
300+
subprocess.run(
301+
[
302+
str(self.ir2vec_binary),
303+
"entities",
304+
"--mode=llvm",
305+
"-o",
306+
str(output_file),
307+
],
308+
check=True,
309+
capture_output=True,
310+
)
230311

231312
def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
232313
"""Generate relation2id.txt from max relation"""
233-
max_relation = max(max_relation, 1) # At least Type and Next relations
314+
max_relation = max(max_relation, 1) # At least Next relation
234315
num_relations = max_relation + 1
235316

236317
with open(output_file, "w") as f:
237318
f.write(f"{num_relations}\n")
238-
f.write("Type\t0\n")
239-
f.write("Next\t1\n")
240-
f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
319+
if self.mode == "llvm":
320+
# LLVM IR has Type relation at 0
321+
f.write("Type\t0\n")
322+
f.write("Next\t1\n")
323+
f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
324+
else:
325+
# MIR doesn't have Type relation, starts with Next at 0
326+
f.write("Next\t0\n")
327+
f.writelines(f"Arg{i-1}\t{i}\n" for i in range(1, num_relations))
241328

242329

243330
def main():
244331
"""Main entry point"""
245332
parser = argparse.ArgumentParser(
246-
description="Generate IR2Vec triplets from LLVM IR files",
333+
description="Generate IR2Vec or MIR2Vec triplets from LLVM IR or Machine IR files",
247334
formatter_class=argparse.RawDescriptionHelpFormatter,
248335
)
249336

@@ -253,16 +340,25 @@ def main():
253340
parser.add_argument(
254341
"num_optimizations",
255342
type=int,
256-
help="Number of optimization levels to apply (1-6)",
343+
nargs="?",
344+
default=1,
345+
help="Number of optimization levels to apply (1-6) for LLVM IR mode",
257346
)
258347
parser.add_argument(
259-
"ll_file_list",
348+
"input_file_list",
260349
type=Path,
261-
help="File containing list of LLVM IR files to process",
350+
help="File containing list of LLVM IR or MIR files to process",
262351
)
263352
parser.add_argument(
264353
"output_dir", type=Path, help="Output directory for generated files"
265354
)
355+
parser.add_argument(
356+
"--mode",
357+
type=str,
358+
choices=["llvm", "mir"],
359+
default="llvm",
360+
help="Operation mode: 'llvm' for LLVM IR (default) or 'mir' for Machine IR",
361+
)
266362
parser.add_argument(
267363
"-j",
268364
"--max-workers",
@@ -296,8 +392,9 @@ def main():
296392
args.num_optimizations,
297393
args.output_dir,
298394
args.max_workers,
395+
args.mode,
299396
)
300-
generator.generate_triplets(args.ll_file_list)
397+
generator.generate_triplets(args.input_file_list)
301398

302399

303400
if __name__ == "__main__":

0 commit comments

Comments
 (0)