Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 131 additions & 34 deletions llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""IR2Vec Triplet Generator
"""IR2Vec/MIR2Vec Triplet Generator

Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
files: entity2id.txt, relation2id.txt, and train2id.txt.
Generates IR2Vec or MIR2Vec triplets by applying random optimization levels to
LLVM IR files (or processing MIR files) and extracting triplets using llvm-ir2vec.
Automatically generates preprocessed files (entity2id.txt, relation2id.txt, and
train2id.txt) necessary for training IR2Vec or MIR2Vec vocabularies.

Usage:
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
For LLVM IR:
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>

For Machine IR:
python generateTriplets.py --mode=mir <llvm_build_dir> <mir_file_list> <output_dir>
"""

import argparse
Expand Down Expand Up @@ -41,19 +46,21 @@ def __init__(self, triplets: Set[str], max_relation: int):


class IR2VecTripletGenerator:
"""Main class for generating IR2Vec triplets"""
"""Main class for generating IR2Vec or MIR2Vec triplets"""

def __init__(
self,
llvm_build_dir: Path,
num_optimizations: int,
output_dir: Path,
max_workers: int = DEFAULT_MAX_WORKERS,
mode: str = "llvm",
):
self.llvm_build_dir = llvm_build_dir
self.num_optimizations = num_optimizations
self.output_dir = output_dir
self.max_workers = max_workers
self.mode = mode # "llvm" or "mir"

# Tool paths
self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
Expand Down Expand Up @@ -85,7 +92,11 @@ def _validate_setup(self):
f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
)

if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
if self.mode not in ["llvm", "mir"]:
raise ValueError(f"Mode must be 'llvm' or 'mir', got: {self.mode}")

# For LLVM IR mode, validate optimization count
if self.mode == "llvm" and not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
raise ValueError(
f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
)
Expand All @@ -95,19 +106,28 @@ def _select_optimization_levels(self) -> List[str]:
return random.sample(OPT_LEVELS, self.num_optimizations)

def _process_single_file(self, input_file: Path) -> TripletResult:
    """Collect the triplets produced for a single LLVM IR or MIR input file.

    In "mir" mode the file is handed to the MIR pipeline exactly once,
    without any optimization pass. In any other mode a random selection of
    optimization levels is applied and the triplets produced at each level
    are merged.

    Args:
        input_file: Path of the LLVM IR or MIR file to process.

    Returns:
        A TripletResult holding the union of all triplets generated for the
        file and the largest relation id seen (never smaller than 1).
    """
    all_triplets = set()
    max_relation = 1

    if self.mode == "mir":
        # For MIR files, process directly without optimization
        triplets, file_max_relation = self._run_mir_pipeline(input_file)
        if triplets:
            all_triplets.update(triplets)
            max_relation = max(max_relation, file_max_relation)
            logger.debug(f"Generated {len(triplets)} triplets for {input_file}")
    else:
        # For LLVM IR files, apply multiple optimization levels
        for opt_level in self._select_optimization_levels():
            triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
            if not triplets:
                # A failed or empty run contributes nothing; try the next level.
                continue
            all_triplets.update(triplets)
            max_relation = max(max_relation, file_max_relation)
            logger.debug(
                f"Generated {len(triplets)} triplets for {input_file} with {opt_level}"
            )

    return TripletResult(all_triplets, max_relation)

Expand All @@ -124,7 +144,7 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int

# Run llvm-ir2vec with opt's output as input
ir2vec_proc = subprocess.Popen(
[self.ir2vec_binary, "triplets", "-", "-o", "-"],
[self.ir2vec_binary, "triplets", "--mode=llvm", "-", "-o", "-"],
stdin=opt_proc.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
Expand All @@ -143,6 +163,32 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int
except (subprocess.SubprocessError, OSError):
return set(), 1

def _run_mir_pipeline(self, input_file: Path) -> Tuple[Set[str], int]:
    """Run llvm-ir2vec in MIR mode on one file and parse its triplets.

    The MIR file is passed to llvm-ir2vec directly; no optimization step
    runs beforehand.

    Args:
        input_file: Path of the MIR file to process.

    Returns:
        The result of `_parse_triplet_output` on the tool's stdout, or
        `(set(), 1)` when the tool cannot be launched or exits non-zero.
    """
    command = [
        self.ir2vec_binary,
        "triplets",
        "--mode=mir",
        str(input_file),
        "-o",
        "-",  # "-" sends the triplets to stdout so they can be captured
    ]

    try:
        # check=False: a failing file is skipped rather than aborting the run.
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=False,
        )
    except (subprocess.SubprocessError, OSError):
        return set(), 1

    if result.returncode != 0:
        return set(), 1

    return self._parse_triplet_output(result.stdout)

def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
"""Parse triplet output and extract max relation"""
if not output.strip():
Expand All @@ -160,12 +206,21 @@ def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
return set(lines), max_relation

def generate_triplets(self, file_list: Path) -> None:
"""Main method to generate triplets from a list of LLVM IR files"""
"""Main method to generate triplets from a list of LLVM IR or MIR files"""
# Store file_list_path for later use in entity generation
self.file_list_path = file_list

input_files = self._read_file_list(file_list)
logger.info(
f"Processing {len(input_files)} files with {self.num_optimizations} "
f"optimization levels using {self.max_workers} workers"
)

if self.mode == "mir":
logger.info(
f"Processing {len(input_files)} MIR files using {self.max_workers} workers"
)
else:
logger.info(
f"Processing {len(input_files)} files with {self.num_optimizations} "
f"optimization levels using {self.max_workers} workers"
)

all_triplets = set()
global_max_relation = 1
Expand Down Expand Up @@ -222,28 +277,60 @@ def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> N

def _generate_entity2id(self, output_file: Path) -> None:
    """Generate entity2id.txt by invoking `llvm-ir2vec entities` once.

    In "mir" mode the tool is given a sample MIR file to determine the
    target; the first entry of the stored input file list is used. In any
    other mode the tool is invoked in LLVM IR mode with no input file.

    Args:
        output_file: Path the entity mapping is written to.

    Raises:
        ValueError: In "mir" mode, if the input file list is empty.
        subprocess.CalledProcessError: If llvm-ir2vec exits non-zero.
    """
    command = [str(self.ir2vec_binary), "entities"]

    if self.mode == "mir":
        # For MIR mode, we need to provide a sample MIR file to determine
        # target. Use the first file from the processed list.
        input_files = self._read_file_list(self.file_list_path)
        if not input_files:
            raise ValueError("No input files available for entity generation")
        command += ["--mode=mir", str(input_files[0])]
    else:
        command.append("--mode=llvm")

    command += ["-o", str(output_file)]
    subprocess.run(command, check=True, capture_output=True)

def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
    """Write relation2id.txt for relation ids 0 through `max_relation`.

    The first line of the file is the number of relations. Each following
    line is "<name>\\t<id>". In "llvm" mode the names are Type (0), Next (1)
    and Arg0, Arg1, ... from id 2 onwards. In any other mode there is no
    Type relation: Next is 0 and Arg0, Arg1, ... start at id 1.

    Args:
        output_file: Path of the file to write.
        max_relation: Largest relation id seen in the triplets; values
            below 1 are raised to 1.
    """
    # Clamp so that relation ids 0 and 1 are always present in the output.
    max_relation = max(max_relation, 1)
    num_relations = max_relation + 1

    if self.mode == "llvm":
        # LLVM IR has Type relation at 0
        fixed_lines = ["Type\t0\n", "Next\t1\n"]
    else:
        # MIR doesn't have Type relation, starts with Next at 0
        fixed_lines = ["Next\t0\n"]
    first_arg_id = len(fixed_lines)

    with open(output_file, "w") as f:
        f.write(f"{num_relations}\n")
        f.writelines(fixed_lines)
        f.writelines(
            f"Arg{rel_id - first_arg_id}\t{rel_id}\n"
            for rel_id in range(first_arg_id, num_relations)
        )


def main():
"""Main entry point"""
parser = argparse.ArgumentParser(
description="Generate IR2Vec triplets from LLVM IR files",
description="Generate IR2Vec or MIR2Vec triplets from LLVM IR or Machine IR files",
formatter_class=argparse.RawDescriptionHelpFormatter,
)

Expand All @@ -253,16 +340,25 @@ def main():
parser.add_argument(
"num_optimizations",
type=int,
help="Number of optimization levels to apply (1-6)",
nargs="?",
default=1,
help="Number of optimization levels to apply (1-6) for LLVM IR mode",
)
parser.add_argument(
"ll_file_list",
"input_file_list",
type=Path,
help="File containing list of LLVM IR files to process",
help="File containing list of LLVM IR or MIR files to process",
)
parser.add_argument(
"output_dir", type=Path, help="Output directory for generated files"
)
parser.add_argument(
"--mode",
type=str,
choices=["llvm", "mir"],
default="llvm",
help="Operation mode: 'llvm' for LLVM IR (default) or 'mir' for Machine IR",
)
parser.add_argument(
"-j",
"--max-workers",
Expand Down Expand Up @@ -296,8 +392,9 @@ def main():
args.num_optimizations,
args.output_dir,
args.max_workers,
args.mode,
)
generator.generate_triplets(args.ll_file_list)
generator.generate_triplets(args.input_file_list)


if __name__ == "__main__":
Expand Down
Loading