RL compilation step with GNN #563
antotu wants to merge 180 commits into munich-quantum-toolkit:main
Conversation
…and test all done
…into gnn-branch
Signed-off-by: Antonio Tudisco <anto.tu98@hotmail.it>
📝 Walkthrough
✏️ Tip: You can customize this high-level summary in your review settings.
This PR integrates Graph Neural Network (GNN) capabilities into the MQT Predictor module. It adds torch, torch-geometric, optuna, and safetensors dependencies; implements a GNN architecture with GraphConvolutionSage and MLP components; introduces GNN-specific training, evaluation, and prediction functions; and extends the prediction pipeline to support both classical (Random Forest) and graph-based (GNN) modeling pathways with Optuna-powered hyperparameter optimization.
Sequence Diagrams
sequenceDiagram
participant Client as Predictor Setup
participant Optuna as Optuna Study
participant GNN as GNN Model
participant Loader as DataLoader
participant Eval as Evaluation
Client->>Optuna: create_study() for hyperparameter optimization
Optuna->>Optuna: sampler.ask() for trial parameters
Optuna->>GNN: instantiate with sampled hyperparams
GNN->>Loader: process training graphs
Loader->>GNN: batch forward pass
GNN->>Eval: compute loss
Eval-->>Optuna: return trial value
Optuna->>Optuna: optimize (repeat until convergence)
Optuna-->>Client: return best hyperparams & model
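The ask/evaluate/tell loop above is Optuna's standard study interface. For orientation, here is a minimal, self-contained sketch of such a loop; the search space and the dummy trial value are illustrative stand-ins, not the hyperparameters or loss actually used in this PR:

import optuna


def objective(trial: optuna.Trial) -> float:
    # Hypothetical search space; the PR's real hyperparameters may differ.
    hidden_dim = trial.suggest_int("hidden_dim", 16, 256, log=True)
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    # Stand-in for "build the GNN, train it, return the validation loss".
    return (hidden_dim - 64) ** 2 + lr


# study.optimize() wraps the sampler's ask/tell cycle shown in the diagram.
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
print(study.best_params)
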
sequenceDiagram
participant Circuit as Quantum Circuit
participant DAG as DAG Constructor
participant GNN as GNN Encoder
participant MLP as MLP Head
participant Output as Device Scores
Circuit->>DAG: decompose & extract graph structure
DAG->>DAG: compute node features (gate types, qubits)
DAG-->>GNN: node_vector, edge_index
GNN->>GNN: apply SAGEConv layers with residuals
GNN->>GNN: optional SAGPooling
GNN->>GNN: global mean pooling
GNN-->>MLP: graph embedding
MLP->>MLP: apply hidden layers with activation
MLP-->>Output: per-device logits/scores
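To make the encoder/head split above concrete, here is a minimal PyTorch Geometric sketch in the same spirit: SAGEConv layers with a residual connection, global mean pooling, and an MLP head producing per-device scores. This is a generic illustration, not the PR's GraphConvolutionSage implementation; all names and dimensions are made up:

import torch
from torch import nn
from torch_geometric.nn import SAGEConv, global_mean_pool


class TinyGnn(nn.Module):
    def __init__(self, in_dim: int, hidden: int, num_devices: int) -> None:
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_devices)
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        h = self.conv1(x, edge_index).relu()
        h = h + self.conv2(h, edge_index).relu()  # residual over the second conv
        g = global_mean_pool(h, batch)  # one embedding per graph in the batch
        return self.head(g)  # per-device logits
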
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing touches
Actionable comments posted: 5
| "optuna>=4.0.0", | ||
| "torch-geometric>=2.7.0,<2.8.0", | ||
| "torch>=2.7.0,<2.8.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore. | ||
| "safetensors>=0.3.0", |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Verify pinned versions exist on PyPI
echo "=== Checking torch 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
echo "✓ torch 2.7.0 found on PyPI"
else
echo "✗ torch 2.7.0 NOT found on PyPI"
fi
echo ""
echo "=== Checking torch-geometric 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch-geometric/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
echo "✓ torch-geometric 2.7.0 found on PyPI"
else
echo "✗ torch-geometric 2.7.0 NOT found on PyPI"
fi
echo ""
echo "=== torch-geometric dependencies on torch ==="
curl -s https://pypi.org/pypi/torch-geometric/2.7.0/json 2>/dev/null | jq -r '.info.requires_dist[]' 2>/dev/null | grep -i 'torch' || echo "No torch dependency found in metadata"

Repository: munich-quantum-toolkit/predictor
Length of output: 493
🏁 Script executed:
# Search for GNN imports and usage
echo "=== Searching for GNN-related imports ==="
find . -type f \( -name "*.py" \) | head -20
echo ""
echo "=== Searching for torch imports ==="
rg "import torch" --type py -A 2 -B 2 | head -100
echo ""
echo "=== Checking for conditional torch imports ==="
rg "try.*import torch|except.*ImportError" --type py -B 2 -A 2Repository: munich-quantum-toolkit/predictor
Length of output: 2922
🏁 Script executed:
cat pyproject.toml | head -100

Repository: munich-quantum-toolkit/predictor
Length of output: 3350
🏁 Script executed:
# Check the full dependencies section
rg "^\[project\]" -A 100 pyproject.toml | head -150Repository: munich-quantum-toolkit/predictor
Length of output: 3277
🏁 Script executed:
# Look for all torch-related entries in pyproject.toml
rg "torch" pyproject.tomlRepository: munich-quantum-toolkit/predictor
Length of output: 827
Torch dependency is platform-conditional but imported unconditionally; installs will break on Linux, Windows, macOS ARM, and Python 3.13+.
torch is restricted to macOS x86 with Python < 3.13, but src/mqt/predictor/ml/gnn.py, predictor.py, and helper.py all import it unconditionally. Since torch-geometric doesn't declare torch as a dependency, installations on Linux, Windows, macOS ARM, or Python 3.13+ will fail with ImportError. Add an unconditional torch requirement or implement import guards with optional extras.
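If the optional-extra route is taken, the guard in each module could look roughly like this; the extra name and the error message are placeholders, not the project's actual packaging:

# Sketch of an import guard for an optional torch dependency.
try:
    import torch
except ImportError as err:
    # Hypothetical extra name; adjust to whatever pyproject.toml defines.
    msg = "GNN support requires torch; install it, e.g. via a 'gnn' extra."
    raise ImportError(msg) from err
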
🤖 Prompt for AI Agents
In `@pyproject.toml` around lines 47 - 50, The pyproject declares torch only for
macOS x86 Python<3.13 but the codebase (files src/mqt/predictor/ml/gnn.py,
src/mqt/predictor/predictor.py, src/mqt/predictor/helper.py) imports torch
unconditionally; fix by either adding an unconditional torch dependency in
pyproject.toml (remove the platform marker so "torch>=2.7.0,<2.8.0" is installed
everywhere) or make torch optional by adding an extras_require (e.g. "torch":
[...]) and wrapping all torch imports in import guards/try-except with clear
fallback errors in the referenced modules (gnn.py, predictor.py, helper.py) to
raise a helpful message if torch is missing; update the package metadata and the
three modules consistently so imports match the dependency strategy.
# file generated by setuptools-scm
# don't change, don't track in version control
Autogenerated _version.py should not be tracked.
The header indicates this file is generated by setuptools-scm; committing it risks stale versions and merge conflicts. Please remove it from VCS and add it to .gitignore, relying on the hatch-vcs hook to regenerate it. Based on learnings, this file should be excluded from version control.
🤖 Prompt for AI Agents
In `@src/mqt/predictor/_version.py` around lines 9 - 11, Remove the autogenerated
file src/mqt/predictor/_version.py from version control (stop tracking it and
delete from the repo tree), add an entry for src/mqt/predictor/_version.py to
.gitignore so it isn't committed again, and ensure the project relies on the
hatch-vcs/hatch hook to regenerate the file at build time; reference the
autogenerated indicator in src/mqt/predictor/_version.py and the hatch-vcs hook
to verify regeneration works after removal.
num_cv = min(len(training_data.y_train), 5)
min_class = min(Counter(training_data.y_train).values())
num_cv = max(2, min(num_cv, min_class))
mdl = GridSearchCV(mdl, tree_param, cv=num_cv, n_jobs=8).fit(training_data.X_train, training_data.y_train)
🧩 Analysis chain
🏁 Script executed:
cd /tmp && git -C /root/repo ls-files src/mqt/predictor/ml/predictor.py

Repository: munich-quantum-toolkit/predictor
Length of output: 141
🏁 Script executed:
# Get context around lines 749-752
head -n 760 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 30 | cat -n

Repository: munich-quantum-toolkit/predictor
Length of output: 179
🏁 Script executed:
# Check the broader function context
head -n 800 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 100 | cat -n

Repository: munich-quantum-toolkit/predictor
Length of output: 179
🏁 Script executed:
git ls-files | grep predictor.py

Repository: munich-quantum-toolkit/predictor
Length of output: 170
🏁 Script executed:
# Also check the current working directory
pwd && ls -la

Repository: munich-quantum-toolkit/predictor
Length of output: 1499
🏁 Script executed:
# List the repository structure
git ls-files | head -20

Repository: munich-quantum-toolkit/predictor
Length of output: 570
🏁 Script executed:
# Read the file around lines 749-752
sed -n '740,760p' src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor
Length of output: 1200
🏁 Script executed:
# Get more context - find the function this code belongs to
sed -n '700,760p' src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor
Length of output: 3255
🏁 Script executed:
# Check the imports at the top of the file
head -n 50 src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor
Length of output: 1886
🏁 Script executed:
# Check if there are any tests for this function
find tests -name "*.py" -type f | xargs grep -l "train_random_forest_model" | head -5

Repository: munich-quantum-toolkit/predictor
Length of output: 137
🏁 Script executed:
# Search for any existing handling of small class sizes
grep -n "min_class\|StratifiedKFold\|cross_val" src/mqt/predictor/ml/predictor.pyRepository: munich-quantum-toolkit/predictor
Length of output: 196
🏁 Script executed:
# Check the test file to understand the expected behavior
cat -n tests/hellinger_distance/test_estimated_hellinger_distance.py | head -60

Repository: munich-quantum-toolkit/predictor
Length of output: 2185
🏁 Script executed:
# Look for any other CV-related code or tests
find tests -name "*.py" -type f | xargs grep -l "GridSearchCV\|cv=" | head -5

Repository: munich-quantum-toolkit/predictor
Length of output: 58
🏁 Script executed:
# Check if there's any documentation or handling of edge cases in the codebase
grep -rn "min_class\|class.*<.*2\|samples.*per.*class" tests/ src/ | head -10Repository: munich-quantum-toolkit/predictor
Length of output: 264
🏁 Script executed:
# Check scikit-learn version to understand GridSearchCV behavior
grep -i "scikit-learn\|sklearn" src/mqt/predictor/ml/predictor.py pyproject.tomlRepository: munich-quantum-toolkit/predictor
Length of output: 658
🏁 Script executed:
# Let's verify the GridSearchCV behavior with a simple test
python3 << 'EOF'
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import inspect
# Check the default cv behavior
print("GridSearchCV default cv parameter:")
print(GridSearchCV.__init__.__doc__[:500])
print("\n---\n")
# Check if StratifiedKFold is used for classifiers
from sklearn.model_selection import _split
print("Let's check what happens with cv=2 and 1 sample per class:")
# Test with minimal data
import numpy as np
from collections import Counter
# Create a minimal problematic case: 2 samples, 1 sample per class
X = np.array([[1], [2]])
y = np.array([0, 1])
print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")
try:
# Try with cv=2
mdl = RandomForestClassifier(random_state=0)
gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
gs.fit(X, y)
print("cv=2 with 1 sample per class: SUCCESS")
except Exception as e:
print(f"cv=2 with 1 sample per class: ERROR - {type(e).__name__}: {e}")
print("\n---\n")
# Now test with min_class < 2
X = np.array([[1], [2], [3]])
y = np.array([0, 1, 0]) # min_class = 1
print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")
try:
mdl = RandomForestClassifier(random_state=0)
gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
gs.fit(X, y)
print("cv=2 with min_class=1: SUCCESS")
except Exception as e:
print(f"cv=2 with min_class=1: ERROR - {type(e).__name__}: {e}")
EOF

Repository: munich-quantum-toolkit/predictor
Length of output: 162
🌐 Web query:
GridSearchCV StratifiedKFold cv parameter classifier sklearn minimum samples per class error
💡 Result:
Cause: StratifiedKFold (used by GridSearchCV for classifiers when cv is an int) requires each class to appear in every fold, so n_splits cannot exceed the number of samples in the least-populated class. If a class has fewer samples than n_splits you get the warning/error. [1][2]
Fixes (pick one):
- Reduce n_splits to <= min_count = y.value_counts().min(). Example: n_splits = min(5, y.value_counts().min()).
- Use non-stratified folds: cv=KFold(n_splits=5).
- Use StratifiedShuffleSplit with an appropriate test_size (keeps class proportions but doesn’t require every fold to contain every class).
- Merge rare classes, oversample (SMOTE/upsampling) or remove extremely rare classes so min_count >= n_splits.
- Use GroupKFold if you have grouping that must be preserved.
Quick code examples:
- Set n_splits dynamically:
  from sklearn.model_selection import GridSearchCV, StratifiedKFold
  n = min(5, y.value_counts().min())
  cv = StratifiedKFold(n_splits=n, shuffle=True, random_state=0)
  gs = GridSearchCV(estimator, param_grid, cv=cv)
- Use KFold:
  from sklearn.model_selection import KFold
  gs = GridSearchCV(estimator, param_grid, cv=KFold(5, shuffle=True))
References: scikit-learn defaults and the warning explanation; community Q&A with solutions. [1][2][3]
Sources:
[1] Stack Overflow — explanation that GridSearchCV uses StratifiedKFold for classifiers and requirement that each class appear in every split. (stackoverflow.com)
[2] Explanation of the error cause (n_splits > samples in a class). (m.php.cn)
[3] scikit-learn discussion/issue about behavior and constraints of StratifiedKFold. (github.com)
Citations:
- 1: https://stackoverflow.com/questions/74445334/userwarning-the-least-populated-class-in-y-has-only-1-members-which-is-less-th?utm_source=openai
- 2: https://m.php.cn/en/faq/1796910302.html?utm_source=openai
- 3: [Feature-Request] Add a flag to StratifiedKFold to force classes with only 1 sample in training scikit-learn/scikit-learn#10767
🏁 Script executed:
# Search for sklearn documentation or comments about cv handling in the codebase
grep -rn "StratifiedKFold\|cv=" src/mqt/predictor/ml/ | head -15Repository: munich-quantum-toolkit/predictor
Length of output: 228
🏁 Script executed:
# Check if there are any related issues or TODOs
grep -rn "TODO\|FIXME\|BUG\|HACK" src/mqt/predictor/ml/predictor.py | grep -i "cv\|class\|fold"Repository: munich-quantum-toolkit/predictor
Length of output: 58
GridSearchCV with StratifiedKFold will fail when min_class < 2.
When min_class equals 1, the code forces num_cv = max(2, min(num_cv, 1)) = 2, which violates scikit-learn's StratifiedKFold requirement that n_splits ≤ min(class_counts). GridSearchCV uses StratifiedKFold by default for classifiers, so this will raise an error at fit time rather than being caught during setup.
Add an explicit guard before GridSearchCV instantiation:
Suggested fix for small class counts
num_cv = min(len(training_data.y_train), 5)
min_class = min(Counter(training_data.y_train).values())
+if min_class < 2:
+ msg = "Not enough samples per class for cross-validation."
+ raise ValueError(msg)
-num_cv = max(2, min(num_cv, min_class))
+num_cv = min(num_cv, min_class)

🤖 Prompt for AI Agents
In `@src/mqt/predictor/ml/predictor.py` around lines 749 - 752, The GridSearchCV
call can fail when a class has only one sample because StratifiedKFold requires
n_splits between 2 and the minimum class count; before creating GridSearchCV in
predictor.py, check min_class and set the CV strategy accordingly: if min_class
>= 2 compute num_cv = max(2, min(original_num_cv, min_class)) and use that
(allowing StratifiedKFold via GridSearchCV), but if min_class < 2 then do not
rely on StratifiedKFold — either skip GridSearchCV and call mdl.fit(...)
directly or instantiate GridSearchCV with a non-stratified splitter (e.g.,
KFold) and appropriate n_splits (>=2 and <=len(y_train)); update the code around
variables num_cv, min_class, GridSearchCV, and mdl to implement this guard and
fallback.
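
As a self-contained illustration of the non-stratified fallback mentioned in the prompt above, the following sketch reproduces the singleton-class case with toy data; the data and parameter grid are made up, and the variable names only mirror the snippet under review:

from collections import Counter

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, KFold

# Toy data with a singleton class (min_class == 1), the case the review flags.
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])
y_train = np.array([0, 0, 0, 1])

num_cv = min(len(y_train), 5)
min_class = min(Counter(y_train).values())
if min_class >= 2:
    cv = min(num_cv, min_class)  # integer cv -> StratifiedKFold for classifiers
else:
    # Stratification is impossible; fall back to a plain (non-stratified) KFold.
    cv = KFold(n_splits=min(max(2, num_cv), len(y_train)))

mdl = GridSearchCV(RandomForestClassifier(random_state=0), {"n_estimators": [10]}, cv=cv)
mdl.fit(X_train, y_train)
print(mdl.best_params_)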
path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
) -> None:
🧹 Nitpick | 🔵 Trivial
Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).
♻️ Suggested signature tweak
-def test_setup_device_predictor_with_prediction(
- path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
-) -> None:
+def test_setup_device_predictor_with_prediction(
+ path_uncompiled_circuits: Path, path_compiled_circuits: Path, *, gnn: bool, verbose: bool
+) -> None:
@@
-def test_get_prepared_training_data_false_input(gnn: bool) -> None:
+def test_get_prepared_training_data_false_input(*, gnn: bool) -> None:

Also applies to: 162-162
🧰 Tools
🪛 Ruff (0.14.13)
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
🤖 Prompt for AI Agents
In `@tests/device_selection/test_predictor_ml.py` around lines 42 - 43, The
function signature that currently accepts path_uncompiled_circuits: Path,
path_compiled_circuits: Path, gnn: bool, verbose: bool should make the boolean
flags keyword-only to avoid positional ambiguity; update the signature by
inserting a keyword-only separator (e.g. add ", *," before gnn) so gnn and
verbose must be passed by keyword (and add sensible defaults like gnn: bool =
False, verbose: bool = False if appropriate), and apply the same change to the
other occurrence referenced in the file.
@pytest.mark.parametrize(
    ("model_type", "verbose"), [("rf", False), ("gnn", False), ("gnn", True)], ids=["rf", "gnn", "gnn_verbose"]
)
def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:
🧹 Nitpick | 🔵 Trivial
Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).
♻️ Suggested signature tweak
-def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:
+def test_train_model_and_predict(device: Target, model_type: str, *, verbose: bool) -> None:
@@
-def test_train_and_qcompile_with_hellinger_model(
- source_path: Path, target_path: Path, device: Target, model_type: str, verbose: bool
-) -> None:
+def test_train_and_qcompile_with_hellinger_model(
+ source_path: Path, target_path: Path, device: Target, model_type: str, *, verbose: bool
+) -> None:

Also applies to: 232-233
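
For context on why the keyword-only change is safe here: pytest supplies parametrized arguments by keyword, so signatures like the one suggested above collect and run unchanged. A tiny standalone sketch (the test name and parameters are illustrative):

import pytest


@pytest.mark.parametrize(("model_type", "verbose"), [("rf", False), ("gnn", True)])
def test_flags_are_keyword_only(model_type: str, *, verbose: bool) -> None:
    # pytest passes parametrized values as keyword arguments,
    # so the keyword-only marker does not break the call.
    assert isinstance(verbose, bool)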
🧰 Tools
🪛 Ruff (0.14.13)
165-165: Boolean-typed positional argument in function definition
(FBT001)
🤖 Prompt for AI Agents
In `@tests/hellinger_distance/test_estimated_hellinger_distance.py` at line 165,
The test function signature test_train_model_and_predict currently allows the
boolean flag verbose to be passed positionally; update the signature to make
boolean flags keyword-only (e.g., def test_train_model_and_predict(device:
Target, model_type: str, *, verbose: bool) -> None) to satisfy Ruff FBT001 and
avoid positional ambiguity; apply the same change to the other test function(s)
with boolean flags reported in the review (the ones around the later tests) so
all boolean parameters are keyword-only.
Description
Please include a summary of the change and, if applicable, which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.
Fixes #(issue)
Checklist: