Skip to content

Commit 8e9a0fc

Browse files
authored
[IR2Vec] Add triplet generation utility script for vocabulary training (#149215)
Added a Python utility script for generating IR2Vec triplets and updated documentation to reference it. The script generates triplets in a form suitable for training the vocabulary. (Tracking issues - #141817, #141834; closes - #141834)
1 parent a14659a commit 8e9a0fc

File tree

2 files changed

+307
-0
lines changed

2 files changed

+307
-0
lines changed

llvm/docs/CommandGuide/llvm-ir2vec.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ embedding training (see
5050
<https://github.com/thunlp/OpenKE/tree/OpenKE-PyTorch?tab=readme-ov-file#data-format>
5151
for details).
5252

53+
See `llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py` for more details on how
54+
these two modes are used to generate the triplets and entity mappings.
55+
5356
Triplet Generation Mode
5457
~~~~~~~~~~~~~~~~~~~~~~~
5558

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
"""IR2Vec Triplet Generator
5+
6+
Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
7+
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
8+
files: entity2id.txt, relation2id.txt, and train2id.txt.
9+
10+
Usage:
11+
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
12+
"""
13+
14+
import argparse
15+
import logging
16+
import os
17+
import random
18+
import subprocess
19+
from concurrent.futures import ThreadPoolExecutor, as_completed
20+
from pathlib import Path
21+
from typing import List, Set, Tuple
22+
23+
# Configuration
24+
OPT_LEVELS = ["O0", "O1", "O2", "O3", "Os", "Oz"]
25+
DEFAULT_MAX_WORKERS = 100
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
# TODO: Change this to a dataclass with slots
31+
# when Python 3.10+ is the minimum version
32+
# https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass
33+
class TripletResult:
    """Outcome of processing one LLVM IR file.

    Attributes are declared through ``__slots__`` so instances carry no
    per-object ``__dict__``.
    """

    # triplets:     the de-duplicated triplet lines collected for the file
    # max_relation: the largest relation value reported for the file
    __slots__ = ["triplets", "max_relation"]

    def __init__(self, triplets: Set[str], max_relation: int):
        self.triplets, self.max_relation = triplets, max_relation
41+
42+
43+
class IR2VecTripletGenerator:
    """Generate IR2Vec triplets from a list of LLVM IR files.

    For every input file a random subset of ``OPT_LEVELS`` is selected. The
    file is piped through ``opt -<level>`` into ``llvm-ir2vec --mode=triplets``
    and the emitted triplets are merged into a single de-duplicated set. The
    merged set is written to ``train2id.txt``; ``entity2id.txt`` and
    ``relation2id.txt`` are generated next to it in the output directory.
    """

    def __init__(
        self,
        llvm_build_dir: Path,
        num_optimizations: int,
        output_dir: Path,
        max_workers: int = DEFAULT_MAX_WORKERS,
    ):
        """Store the configuration, validate it, and create the output directory.

        Args:
            llvm_build_dir: LLVM build directory; ``bin/opt`` and
                ``bin/llvm-ir2vec`` are looked up beneath it.
            num_optimizations: How many distinct optimization levels to apply
                to each input file (between 1 and ``len(OPT_LEVELS)``).
            output_dir: Directory that receives the generated files. It is
                created (including parents) if it does not exist.
            max_workers: Size of the thread pool used to process files.

        Raises:
            FileNotFoundError: If the build directory or a required tool is
                missing or not executable.
            ValueError: If ``num_optimizations`` or ``max_workers`` is out of
                range.
        """
        self.llvm_build_dir = llvm_build_dir
        self.num_optimizations = num_optimizations
        self.output_dir = output_dir
        self.max_workers = max_workers

        # Tool paths
        self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
        self.ir2vec_binary = os.path.join(llvm_build_dir, "bin", "llvm-ir2vec")

        self._validate_setup()

        # Create output directory if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _validate_setup(self):
        """Validate that all required tools exist and the settings are sane.

        Raises:
            FileNotFoundError: If the build directory, ``opt`` or
                ``llvm-ir2vec`` cannot be found or executed.
            ValueError: If ``num_optimizations`` is outside
                ``1..len(OPT_LEVELS)`` or ``max_workers`` is not positive.
        """
        if not self.llvm_build_dir.exists():
            raise FileNotFoundError(
                f"LLVM build directory not found: {self.llvm_build_dir}"
            )

        if not os.path.isfile(self.opt_binary) or not os.access(
            self.opt_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"opt binary not found or not executable: {self.opt_binary}"
            )

        if not os.path.isfile(self.ir2vec_binary) or not os.access(
            self.ir2vec_binary, os.X_OK
        ):
            raise FileNotFoundError(
                f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
            )

        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
            raise ValueError(
                f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
            )

        # Fail early instead of waiting for ThreadPoolExecutor to reject the
        # value after the input list has already been read.
        if self.max_workers < 1:
            raise ValueError("Number of workers must be at least 1")

    def _select_optimization_levels(self) -> List[str]:
        """Return ``num_optimizations`` distinct, randomly chosen levels."""
        return random.sample(OPT_LEVELS, self.num_optimizations)

    def _process_single_file(self, input_file: Path) -> TripletResult:
        """Run one input file through every selected optimization level.

        Args:
            input_file: Path to the LLVM IR file to process.

        Returns:
            A ``TripletResult`` with the union of all triplets produced for
            the file and the largest relation value seen (at least 1).
        """
        all_triplets = set()
        max_relation = 1

        for opt_level in self._select_optimization_levels():
            triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
            if triplets:
                all_triplets.update(triplets)
                max_relation = max(max_relation, file_max_relation)
                logger.debug(
                    "Generated %d triplets for %s with %s",
                    len(triplets),
                    input_file,
                    opt_level,
                )

        return TripletResult(all_triplets, max_relation)

    def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
        """Run the ``opt | llvm-ir2vec`` pipeline using subprocess pipes.

        Args:
            input_file: Path to the LLVM IR file to optimize.
            opt_level: Optimization level name without the leading dash.

        Returns:
            ``(triplets, max_relation)`` parsed from llvm-ir2vec's output, or
            ``(set(), 1)`` if either tool fails or cannot be launched.

        Raises:
            ValueError: If the ``MAX_RELATION=`` metadata line is malformed.
        """
        try:
            # opt's stderr is discarded rather than piped: nothing reads it,
            # and an unread pipe can fill up and block opt indefinitely.
            # The context managers guarantee both children are reaped and
            # their pipes closed even if launching llvm-ir2vec fails.
            with subprocess.Popen(
                [self.opt_binary, f"-{opt_level}", str(input_file), "-o", "-"],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
            ) as opt_proc:
                with subprocess.Popen(
                    [self.ir2vec_binary, "--mode=triplets", "-", "-o", "-"],
                    stdin=opt_proc.stdout,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                ) as ir2vec_proc:
                    # Drop our copy of the read end so opt is notified if
                    # llvm-ir2vec exits before consuming all of its input.
                    opt_proc.stdout.close()
                    stdout, stderr = ir2vec_proc.communicate()
        except (subprocess.SubprocessError, OSError) as e:
            logger.warning(
                "Failed to run pipeline for %s with %s: %s", input_file, opt_level, e
            )
            return set(), 1

        # Check if either process failed
        if opt_proc.returncode != 0 or ir2vec_proc.returncode != 0:
            logger.warning(
                "Pipeline failed for %s with %s "
                "(opt exit code %s, llvm-ir2vec exit code %s): %s",
                input_file,
                opt_level,
                opt_proc.returncode,
                ir2vec_proc.returncode,
                stderr.strip(),
            )
            return set(), 1

        return self._parse_triplet_output(stdout)

    def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
        """Parse llvm-ir2vec triplet output and extract the max relation.

        Blank lines and surrounding whitespace (including ``\\r`` from CRLF
        line endings) are ignored so they never end up as bogus triplets.

        Args:
            output: Raw text written by ``llvm-ir2vec --mode=triplets``.

        Returns:
            ``(triplets, max_relation)``. ``max_relation`` is taken from a
            leading ``MAX_RELATION=<n>`` line when present and defaults to 1.

        Raises:
            ValueError: If the ``MAX_RELATION=`` value is not an integer.
        """
        lines = [stripped for line in output.splitlines() if (stripped := line.strip())]
        if not lines:
            return set(), 1

        max_relation = 1

        # Extract max relation from metadata line
        if lines[0].startswith("MAX_RELATION="):
            max_relation = int(lines[0].split("=")[1])
            lines = lines[1:]

        # Remove duplicate triplets by converting to a set
        return set(lines), max_relation

    def generate_triplets(self, file_list: Path) -> None:
        """Generate triplets for every file named in ``file_list``.

        Files are processed concurrently on a thread pool. Per-file errors
        are logged and do not abort the run.

        Args:
            file_list: Text file with one LLVM IR file path per line.

        Raises:
            ValueError: If ``file_list`` names no usable input file.
            subprocess.CalledProcessError: If ``llvm-ir2vec --mode=entities``
                fails while writing ``entity2id.txt``.
        """
        input_files = self._read_file_list(file_list)
        logger.info(
            "Processing %d files with %d optimization levels using %d workers",
            len(input_files),
            self.num_optimizations,
            self.max_workers,
        )

        all_triplets = set()
        global_max_relation = 1

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._process_single_file, file): file
                for file in input_files
            }

            for future in as_completed(future_to_file):
                try:
                    result = future.result()
                    all_triplets.update(result.triplets)
                    global_max_relation = max(global_max_relation, result.max_relation)
                except (subprocess.SubprocessError, OSError, ValueError) as e:
                    file_path = future_to_file[future]
                    logger.error("Error processing %s: %s", file_path, e)

        if not all_triplets:
            # Still emit the output files, but make the empty result visible
            # instead of only reporting success.
            logger.warning("No triplets were generated from the input files")

        self._generate_output_files(all_triplets, global_max_relation)
        logger.info("Processing completed successfully")

    def _read_file_list(self, file_list: Path) -> List[Path]:
        """Read the list of input files, skipping entries that are unusable.

        Args:
            file_list: Text file with one path per line; blank lines are
                ignored.

        Returns:
            The paths that refer to existing regular files, in file order.

        Raises:
            ValueError: If no listed path is a usable file.
        """
        input_files = []
        with open(file_list, "r") as f:
            for line_num, line in enumerate(f, 1):
                if line := line.strip():
                    file_path = Path(line)
                    # is_file() rather than exists(): a directory cannot be
                    # fed to opt and would only fail later in the pipeline.
                    if file_path.is_file():
                        input_files.append(file_path)
                    else:
                        logger.warning(
                            "File not found or not a regular file (line %d): %s",
                            line_num,
                            file_path,
                        )

        if not input_files:
            raise ValueError("No valid input files found")
        return input_files

    def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
        """Write ``train2id.txt``, ``entity2id.txt`` and ``relation2id.txt``.

        Args:
            all_triplets: Unique triplet lines to write to ``train2id.txt``.
            max_relation: Largest relation value seen across all inputs.
        """
        logger.info(
            "Generating output files with %d unique triplets", len(all_triplets)
        )

        # Write all output files -- train2id.txt, entity2id.txt, relation2id.txt
        train2id_file = self.output_dir / "train2id.txt"
        entity2id_file = self.output_dir / "entity2id.txt"
        relation2id_file = self.output_dir / "relation2id.txt"

        with open(train2id_file, "w") as f:
            # First line holds the triplet count, followed by one triplet per line.
            f.write(f"{len(all_triplets)}\n")
            f.writelines(f"{triplet}\n" for triplet in all_triplets)

        self._generate_entity2id(entity2id_file)
        self._generate_relation2id(relation2id_file, max_relation)

    def _generate_entity2id(self, output_file: Path) -> None:
        """Generate ``entity2id.txt`` by running ``llvm-ir2vec --mode=entities``.

        Args:
            output_file: Destination path handed to llvm-ir2vec via ``-o``.

        Raises:
            subprocess.CalledProcessError: If llvm-ir2vec exits with a
                non-zero status; its stderr is logged before re-raising.
        """
        try:
            subprocess.run(
                [str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # The captured stderr is not part of the exception message, so
            # surface it here; otherwise the cause of the failure is lost.
            logger.error(
                "llvm-ir2vec failed to generate %s: %s",
                output_file,
                e.stderr.decode(errors="replace").strip(),
            )
            raise

    def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
        """Generate ``relation2id.txt`` from the maximum relation value.

        Relation 0 is ``Type``, relation 1 is ``Next`` and relations from 2
        upward are written as ``Arg0``, ``Arg1``, ...

        Args:
            output_file: Destination path for the relation mapping.
            max_relation: Largest relation value to include.
        """
        max_relation = max(max_relation, 1)  # At least Type and Next relations
        num_relations = max_relation + 1

        with open(output_file, "w") as f:
            f.write(f"{num_relations}\n")
            f.write("Type\t0\n")
            f.write("Next\t1\n")
            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
241+
242+
243+
def main():
    """Parse the command line, configure logging, and run the generator."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Generate IR2Vec triplets from LLVM IR files",
    )

    # Positional arguments, registered in command-line order.
    positionals = (
        ("llvm_build_dir", Path, "Path to LLVM build directory"),
        ("num_optimizations", int, "Number of optimization levels to apply (1-6)"),
        ("ll_file_list", Path, "File containing list of LLVM IR files to process"),
        ("output_dir", Path, "Output directory for generated files"),
    )
    for arg_name, arg_type, arg_help in positionals:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)

    cli.add_argument(
        "-j",
        "--max-workers",
        default=DEFAULT_MAX_WORKERS,
        type=int,
        help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
    )

    # Boolean switches controlling log verbosity.
    switches = (
        ("-v", "--verbose", "Enable debug logging"),
        ("-q", "--quiet", "Suppress all output except errors"),
    )
    for short_flag, long_flag, flag_help in switches:
        cli.add_argument(short_flag, long_flag, action="store_true", help=flag_help)

    options = cli.parse_args()

    # --quiet takes precedence over --verbose when both are given.
    if options.quiet:
        log_level = logging.ERROR
    elif options.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO

    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
        level=log_level,
    )

    IR2VecTripletGenerator(
        options.llvm_build_dir,
        options.num_optimizations,
        options.output_dir,
        options.max_workers,
    ).generate_triplets(options.ll_file_list)
301+
302+
303+
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)