
RL compilation step wit GNN #563

Open
antotu wants to merge 180 commits into munich-quantum-toolkit:main from antotu:RL-compilation-step
Conversation


@antotu antotu commented Jan 8, 2026

Description

Please include a summary of the change and, if applicable, which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.

Fixes #(issue)

Checklist:

  • The pull request only contains commits that are focused and relevant to this change.
  • I have added appropriate tests that cover the new/changed functionality.
  • I have updated the documentation to reflect these changes.
  • I have added entries to the changelog for any noteworthy additions, changes, fixes, or removals.
  • I have added migration instructions to the upgrade guide (if needed).
  • The changes follow the project's style guidelines and introduce no new warnings.
  • The changes are fully tested and pass the CI checks.
  • I have reviewed my own code changes.

antotu and others added 30 commits August 19, 2025 13:47
@antotu antotu marked this pull request as ready for review January 20, 2026 12:17

coderabbitai bot commented Jan 20, 2026

📝 Walkthrough

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Graph Neural Network (GNN) support for quantum device prediction, offering an alternative to the existing random forest models
    • Introduced Hellinger distance estimation using GNN models
    • Added automatic version tracking
  • Chores

    • Updated dependencies to include PyTorch, Torch Geometric, Optuna, and Safetensors
    • Enhanced type checking and linting configurations
    • Expanded test coverage for new GNN functionality


Walkthrough

This PR integrates Graph Neural Network (GNN) capabilities into the MQT Predictor module. It adds torch, torch-geometric, optuna, and safetensors dependencies; implements GNN architecture with GraphConvolutionSage and MLP components; introduces GNN-specific training, evaluation, and prediction functions; and extends the prediction pipeline to support both classical (Random Forest) and graph-based (GNN) modeling pathways with Optuna-powered hyperparameter optimization.

Changes

  • Configuration & Dependencies (pyproject.toml): Added runtime dependencies optuna, torch, torch-geometric, and safetensors. Expanded mypy overrides for the torch and safetensors modules and introduced a dedicated GNN-module mypy override. Updated the Ruff config with namespace-packages, preview, and fix options. Extended typo allowances.
  • Version Tracking (src/mqt/predictor/_version.py): New module exporting version metadata (version, __version__, version_tuple, __version_tuple__, commit_id, __commit_id__) generated by setuptools-scm, with TYPE_CHECKING type annotations.
  • GNN Core Implementation (src/mqt/predictor/ml/gnn.py): New GraphConvolutionSage class with SAGEConv stacks, optional residual connections, bidirectional aggregation, and SAGPooling. New GNN class combining the encoder and an MLP head for graph-based predictions. Includes per-layer normalization, dropout, and configurable activation functions.
  • ML Pipeline Updates (src/mqt/predictor/ml/helper.py, src/mqt/predictor/ml/predictor.py): Added GNN-specific path resolution, DAG construction from circuits, and evaluation/training functions for classification and regression. Extended TrainingData to support mixed classical and graph data. Refactored the predictor to branch on a gnn flag; added Optuna-based hyperparameter optimization and a GNN training pipeline with safetensors-based serialization.
  • Module Exports (src/mqt/predictor/ml/__init__.py): Reformatted __all__ to multi-line syntax; no functional changes to the exported entities.
  • Hellinger Distance Utilities (src/mqt/predictor/hellinger/utils.py): Updated get_hellinger_model_path to accept an optional gnn flag; adjusted path construction to include a trained_model subdirectory and to select the filename based on model type.
  • RL Environment (src/mqt/predictor/rl/predictorenv.py): Updated the TKET-to-Qiskit conversion in _apply_tket_action to use replace_implicit_swaps=True for explicit SWAP gate handling.
  • Test Coverage (tests/compilation/test_predictor_rl.py, tests/device_selection/test_helper_ml.py, tests/device_selection/test_predictor_ml.py, tests/hellinger_distance/test_estimated_hellinger_distance.py): Added parameterized tests for GNN vs. RF models and new DAG-construction tests. Extended predictor tests for multi-device GNN training and per-model validation. Introduced cleanup fixtures and dynamic device parameterization.

Sequence Diagrams

sequenceDiagram
    participant Client as Predictor Setup
    participant Optuna as Optuna Study
    participant GNN as GNN Model
    participant Loader as DataLoader
    participant Eval as Evaluation

    Client->>Optuna: create_study() for hyperparameter optimization
    Optuna->>Optuna: sampler.ask() for trial parameters
    Optuna->>GNN: instantiate with sampled hyperparams
    GNN->>Loader: process training graphs
    Loader->>GNN: batch forward pass
    GNN->>Eval: compute loss
    Eval-->>Optuna: return trial value
    Optuna->>Optuna: optimize (repeat until convergence)
    Optuna-->>Client: return best hyperparams & model
sequenceDiagram
    participant Circuit as Quantum Circuit
    participant DAG as DAG Constructor
    participant GNN as GNN Encoder
    participant MLP as MLP Head
    participant Output as Device Scores

    Circuit->>DAG: decompose & extract graph structure
    DAG->>DAG: compute node features (gate types, qubits)
    DAG-->>GNN: node_vector, edge_index
    GNN->>GNN: apply SAGEConv layers with residuals
    GNN->>GNN: optional SAGPooling
    GNN->>GNN: global mean pooling
    GNN-->>MLP: graph embedding
    MLP->>MLP: apply hidden layers with activation
    MLP-->>Output: per-device logits/scores
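The DAG-construction step in the inference flow above can be illustrated with plain Python: one-hot gate-type node features plus an edge list in the two-row COO layout (`edge_index`) that torch-geometric expects. The toy circuit and feature scheme here are illustrative, not the actual mqt-predictor encoding:

```python
# Toy circuit as a list of (gate_name, qubits) in execution order.
circuit = [("h", (0,)), ("cx", (0, 1)), ("rz", (1,)), ("cx", (0, 1))]
gate_types = ["h", "cx", "rz"]

# Node features: one-hot gate type plus the number of qubits acted on.
node_features = [
    [1.0 if name == g else 0.0 for g in gate_types] + [float(len(qubits))]
    for name, qubits in circuit
]

# Edges: connect consecutive operations on the same qubit (DAG dependencies).
edges: list[tuple[int, int]] = []
last_op_on_qubit: dict[int, int] = {}
for node_id, (_name, qubits) in enumerate(circuit):
    for q in qubits:
        if q in last_op_on_qubit:
            edges.append((last_op_on_qubit[q], node_id))
        last_op_on_qubit[q] = node_id

# COO layout: row 0 holds source nodes, row 1 holds target nodes.
edge_index = [[s for s, _ in edges], [t for _, t in edges]]
print(edge_index)  # [[0, 1, 1, 2], [1, 2, 3, 3]]
```

The SAGEConv layers then aggregate each node's neighborhood along these edges before pooling to a single graph embedding.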

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • DDSIM #770: Updates mypy configuration in pyproject.toml to accommodate new dependencies, intersecting with this PR's mypy override expansions.
  • Bench #794: Modifies static type-checking configuration (pyproject.toml and CI) and may conflict with mypy vs. type-checker tool changes introduced here.

Poem

🐰 A graph of gates, so nodes galore,
SAGEConv stacks to explore,
With Optuna's tune and pooling deep,
GNN dreams the circuit will keep! 🌱✨

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
  • Description check (⚠️ Warning): The PR description is entirely the repository's template with no concrete information filled in; no summary, motivation, dependencies, issue references, or checklist items are completed. Resolution: provide a detailed description of the changes, including a summary, relevant motivation and context, the list of new dependencies (torch, torch-geometric, optuna, safetensors), and mark completed checklist items.
  • Title check (❓ Inconclusive): The title "RL compilation step wit GNN" appears to describe GNN integration for RL compilation, but it contains a typo ("wit" instead of "with") and is vague about the specific changes being introduced. Resolution: correct the typo and clarify the title to better reflect the main changes, such as "Add GNN support for device prediction in RL compilation".

✅ Passed checks (1 passed)
  • Docstring Coverage (✅ Passed): Docstring coverage is 94.64%, which is sufficient; the required threshold is 80.00%.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@pyproject.toml`:
- Around line 47-50: The pyproject declares torch only for macOS x86 Python<3.13
but the codebase (files src/mqt/predictor/ml/gnn.py,
src/mqt/predictor/predictor.py, src/mqt/predictor/helper.py) imports torch
unconditionally; fix by either adding an unconditional torch dependency in
pyproject.toml (remove the platform marker so "torch>=2.7.0,<2.8.0" is installed
everywhere) or make torch optional by adding an extras_require (e.g. "torch":
[...]) and wrapping all torch imports in import guards/try-except with clear
fallback errors in the referenced modules (gnn.py, predictor.py, helper.py) to
raise a helpful message if torch is missing; update the package metadata and the
three modules consistently so imports match the dependency strategy.

In `@src/mqt/predictor/_version.py`:
- Around line 9-11: Remove the autogenerated file src/mqt/predictor/_version.py
from version control (stop tracking it and delete from the repo tree), add an
entry for src/mqt/predictor/_version.py to .gitignore so it isn't committed
again, and ensure the project relies on the hatch-vcs/hatch hook to regenerate
the file at build/time; reference the autogenerated indicator in
src/mqt/predictor/_version.py and the hatch-vcs hook to verify regeneration
works after removal.

In `@src/mqt/predictor/ml/predictor.py`:
- Around line 749-752: The GridSearchCV call can fail when a class has only one
sample because StratifiedKFold requires n_splits between 2 and the minimum class
count; before creating GridSearchCV in predictor.py, check min_class and set the
CV strategy accordingly: if min_class >= 2 compute num_cv = max(2,
min(original_num_cv, min_class)) and use that (allowing StratifiedKFold via
GridSearchCV), but if min_class < 2 then do not rely on StratifiedKFold — either
skip GridSearchCV and call mdl.fit(...) directly or instantiate GridSearchCV
with a non-stratified splitter (e.g., KFold) and appropriate n_splits (>=2 and
<=len(y_train)); update the code around variables num_cv, min_class,
GridSearchCV, and mdl to implement this guard and fallback.

In `@tests/device_selection/test_predictor_ml.py`:
- Around line 42-43: The function signature that currently accepts
path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool,
verbose: bool should make the boolean flags keyword-only to avoid positional
ambiguity; update the signature by inserting a positional-only separator (e.g.
add ", *," before gnn) so gnn and verbose must be passed by keyword (and add
sensible defaults like gnn: bool = False, verbose: bool = False if appropriate),
and apply the same change to the other occurrence referenced in the file.

In `@tests/hellinger_distance/test_estimated_hellinger_distance.py`:
- Line 165: The test function signature test_train_model_and_predict currently
allows the boolean flag verbose to be passed positionally; update the signature
to make boolean flags keyword-only (e.g., def
test_train_model_and_predict(device: Target, model_type: str, *, verbose: bool)
-> None) to satisfy Ruff FBT001 and avoid positional ambiguity; apply the same
change to the other test function(s) with boolean flags reported in the review
(the ones around the later tests) so all boolean parameters are keyword-only.

Comment on lines +47 to +50
"optuna>=4.0.0",
"torch-geometric>=2.7.0,<2.8.0",
"torch>=2.7.0,<2.8.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore.
"safetensors>=0.3.0",

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Verify pinned versions exist on PyPI
echo "=== Checking torch 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
  echo "✓ torch 2.7.0 found on PyPI"
else
  echo "✗ torch 2.7.0 NOT found on PyPI"
fi

echo ""
echo "=== Checking torch-geometric 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch-geometric/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
  echo "✓ torch-geometric 2.7.0 found on PyPI"
else
  echo "✗ torch-geometric 2.7.0 NOT found on PyPI"
fi

echo ""
echo "=== torch-geometric dependencies on torch ==="
curl -s https://pypi.org/pypi/torch-geometric/2.7.0/json 2>/dev/null | jq -r '.info.requires_dist[]' 2>/dev/null | grep -i 'torch' || echo "No torch dependency found in metadata"

Repository: munich-quantum-toolkit/predictor

Length of output: 493


🏁 Script executed:

# Search for GNN imports and usage
echo "=== Searching for GNN-related imports ==="
find . -type f \( -name "*.py" \) | head -20

echo ""
echo "=== Searching for torch imports ==="
rg "import torch" --type py -A 2 -B 2 | head -100

echo ""
echo "=== Checking for conditional torch imports ==="
rg "try.*import torch|except.*ImportError" --type py -B 2 -A 2

Repository: munich-quantum-toolkit/predictor

Length of output: 2922


🏁 Script executed:

cat pyproject.toml | head -100

Repository: munich-quantum-toolkit/predictor

Length of output: 3350


🏁 Script executed:

# Check the full dependencies section
rg "^\[project\]" -A 100 pyproject.toml | head -150

Repository: munich-quantum-toolkit/predictor

Length of output: 3277


🏁 Script executed:

# Look for all torch-related entries in pyproject.toml
rg "torch" pyproject.toml

Repository: munich-quantum-toolkit/predictor

Length of output: 827


Torch dependency is platform-conditional but imported unconditionally—will break on Linux, Windows, macOS ARM, and Python 3.13+.

torch is restricted to macOS x86 with Python < 3.13, but src/mqt/predictor/ml/gnn.py, predictor.py, and helper.py all import it unconditionally. Since torch-geometric doesn't declare torch as a dependency, installations on Linux, Windows, macOS ARM, or Python 3.13+ will fail with ImportError. Add an unconditional torch requirement or implement import guards with optional extras.

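One way to realize the import-guard option is a small helper that turns a missing optional dependency into an actionable error. This is a sketch; the helper name and the `[gnn]` extra shown in the comment are hypothetical, not part of mqt-predictor:

```python
import importlib
from types import ModuleType

def require_module(name: str, hint: str) -> ModuleType:
    """Import an optional dependency, or fail with an actionable message."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        msg = f"'{name}' is required for this feature. {hint}"
        raise ImportError(msg) from exc

# At the top of a GNN module one could then write (extra name is hypothetical):
# torch = require_module("torch", "Install it via `pip install mqt.predictor[gnn]`.")
```

The alternative the review mentions, declaring torch unconditionally in pyproject.toml, avoids the guard entirely at the cost of a heavy mandatory dependency.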

Comment on lines +9 to +11
# file generated by setuptools-scm
# don't change, don't track in version control


⚠️ Potential issue | 🟠 Major

Autogenerated _version.py should not be tracked.

The header indicates this file is generated by setuptools-scm; committing it risks stale versions and merge conflicts. Please remove it from VCS and add it to .gitignore, relying on the hatch-vcs hook to regenerate it. Based on learnings, this file should be excluded from version control.


Comment on lines 749 to 752
num_cv = min(len(training_data.y_train), 5)
min_class = min(Counter(training_data.y_train).values())
num_cv = max(2, min(num_cv, min_class))
mdl = GridSearchCV(mdl, tree_param, cv=num_cv, n_jobs=8).fit(training_data.X_train, training_data.y_train)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cd /tmp && git -C /root/repo ls-files src/mqt/predictor/ml/predictor.py

Repository: munich-quantum-toolkit/predictor

Length of output: 141


🏁 Script executed:

# Get context around lines 749-752
head -n 760 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 30 | cat -n

Repository: munich-quantum-toolkit/predictor

Length of output: 179


🏁 Script executed:

# Check the broader function context
head -n 800 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 100 | cat -n

Repository: munich-quantum-toolkit/predictor

Length of output: 179


🏁 Script executed:

git ls-files | grep predictor.py

Repository: munich-quantum-toolkit/predictor

Length of output: 170


🏁 Script executed:

# Also check the current working directory
pwd && ls -la

Repository: munich-quantum-toolkit/predictor

Length of output: 1499


🏁 Script executed:

# List the repository structure
git ls-files | head -20

Repository: munich-quantum-toolkit/predictor

Length of output: 570


🏁 Script executed:

# Read the file around lines 749-752
sed -n '740,760p' src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor

Length of output: 1200


🏁 Script executed:

# Get more context - find the function this code belongs to
sed -n '700,760p' src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor

Length of output: 3255


🏁 Script executed:

# Check the imports at the top of the file
head -n 50 src/mqt/predictor/ml/predictor.py | cat -n

Repository: munich-quantum-toolkit/predictor

Length of output: 1886


🏁 Script executed:

# Check if there are any tests for this function
find tests -name "*.py" -type f | xargs grep -l "train_random_forest_model" | head -5

Repository: munich-quantum-toolkit/predictor

Length of output: 137


🏁 Script executed:

# Search for any existing handling of small class sizes
grep -n "min_class\|StratifiedKFold\|cross_val" src/mqt/predictor/ml/predictor.py

Repository: munich-quantum-toolkit/predictor

Length of output: 196


🏁 Script executed:

# Check the test file to understand the expected behavior
cat -n tests/hellinger_distance/test_estimated_hellinger_distance.py | head -60

Repository: munich-quantum-toolkit/predictor

Length of output: 2185


🏁 Script executed:

# Look for any other CV-related code or tests
find tests -name "*.py" -type f | xargs grep -l "GridSearchCV\|cv=" | head -5

Repository: munich-quantum-toolkit/predictor

Length of output: 58


🏁 Script executed:

# Check if there's any documentation or handling of edge cases in the codebase
grep -rn "min_class\|class.*<.*2\|samples.*per.*class" tests/ src/ | head -10

Repository: munich-quantum-toolkit/predictor

Length of output: 264


🏁 Script executed:

# Check scikit-learn version to understand GridSearchCV behavior
grep -i "scikit-learn\|sklearn" src/mqt/predictor/ml/predictor.py pyproject.toml

Repository: munich-quantum-toolkit/predictor

Length of output: 658


🏁 Script executed:

# Let's verify the GridSearchCV behavior with a simple test
python3 << 'EOF'
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import inspect

# Check the default cv behavior
print("GridSearchCV default cv parameter:")
print(GridSearchCV.__init__.__doc__[:500])
print("\n---\n")

# Check if StratifiedKFold is used for classifiers
from sklearn.model_selection import _split
print("Let's check what happens with cv=2 and 1 sample per class:")

# Test with minimal data
import numpy as np
from collections import Counter

# Create a minimal problematic case: 2 samples, 1 sample per class
X = np.array([[1], [2]])
y = np.array([0, 1])

print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")

try:
    # Try with cv=2
    mdl = RandomForestClassifier(random_state=0)
    gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
    gs.fit(X, y)
    print("cv=2 with 1 sample per class: SUCCESS")
except Exception as e:
    print(f"cv=2 with 1 sample per class: ERROR - {type(e).__name__}: {e}")

print("\n---\n")

# Now test with min_class < 2
X = np.array([[1], [2], [3]])
y = np.array([0, 1, 0])  # min_class = 1

print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")

try:
    mdl = RandomForestClassifier(random_state=0)
    gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
    gs.fit(X, y)
    print("cv=2 with min_class=1: SUCCESS")
except Exception as e:
    print(f"cv=2 with min_class=1: ERROR - {type(e).__name__}: {e}")
EOF

Repository: munich-quantum-toolkit/predictor

Length of output: 162


🌐 Web query:

GridSearchCV StratifiedKFold cv parameter classifier sklearn minimum samples per class error

💡 Result:

Cause: StratifiedKFold (used by GridSearchCV for classifiers when cv is an int) requires each class to appear in every fold, so n_splits cannot exceed the number of samples in the least-populated class. If a class has fewer samples than n_splits you get the warning/error. [1][2]

Fixes (pick one):

  • Reduce n_splits to <= min_count = y.value_counts().min().
    Example: n_splits = min(5, y.value_counts().min()).
  • Use non‑stratified folds: cv=KFold(n_splits=5).
  • Use StratifiedShuffleSplit with an appropriate test_size (keeps class proportions but doesn’t require every fold to contain every class).
  • Merge rare classes, oversample (SMOTE/upsampling) or remove extremely rare classes so min_count >= n_splits.
  • Use GroupKFold if you have grouping that must be preserved.

Quick code examples:

  • set n_splits dynamically:
    from sklearn.model_selection import GridSearchCV, StratifiedKFold
    n = min(5, y.value_counts().min())
    cv = StratifiedKFold(n_splits=n, shuffle=True, random_state=0)
    gs = GridSearchCV(estimator, param_grid, cv=cv)
  • use KFold:
    from sklearn.model_selection import KFold
    gs = GridSearchCV(estimator, param_grid, cv=KFold(5, shuffle=True))

References: scikit-learn defaults and the warning explanation; community Q&A with solutions. [1][2][3]

Sources:
[1] Stack Overflow — explanation that GridSearchCV uses StratifiedKFold for classifiers and requirement that each class appear in every split. (stackoverflow.com)
[2] Explanation of the error cause (n_splits > samples in a class). (m.php.cn)
[3] scikit-learn discussion/issue about behavior and constraints of StratifiedKFold. (github.com)

Citations:


🏁 Script executed:

# Search for sklearn documentation or comments about cv handling in the codebase
grep -rn "StratifiedKFold\|cv=" src/mqt/predictor/ml/ | head -15

Repository: munich-quantum-toolkit/predictor

Length of output: 228


🏁 Script executed:

# Check if there are any related issues or TODOs
grep -rn "TODO\|FIXME\|BUG\|HACK" src/mqt/predictor/ml/predictor.py | grep -i "cv\|class\|fold"

Repository: munich-quantum-toolkit/predictor

Length of output: 58


GridSearchCV with StratifiedKFold will fail when min_class < 2.

When min_class equals 1, the code forces num_cv = max(2, min(num_cv, 1)) = 2, which violates scikit-learn's StratifiedKFold requirement that n_splits ≤ min(class_counts). GridSearchCV uses StratifiedKFold by default for classifiers, so this will raise an error at fit time rather than being caught during setup.

Add an explicit guard before GridSearchCV instantiation:

Suggested fix for small class counts
num_cv = min(len(training_data.y_train), 5)
min_class = min(Counter(training_data.y_train).values())
+if min_class < 2:
+    msg = "Not enough samples per class for cross-validation."
+    raise ValueError(msg)
-num_cv = max(2, min(num_cv, min_class))
+num_cv = min(num_cv, min_class)
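The guard the reviewer proposes can be captured in a small helper; this stdlib-only sketch mirrors the variable names from the diff but is not the actual predictor code:

```python
from collections import Counter

def choose_num_cv(y_train: list[int], max_cv: int = 5) -> int:
    """Pick a StratifiedKFold-safe number of splits, or fail early.

    StratifiedKFold requires 2 <= n_splits <= min(class counts), so a
    class with a single sample cannot be stratified at all; better to
    raise a clear error than let GridSearchCV fail at fit time.
    """
    min_class = min(Counter(y_train).values())
    if min_class < 2:
        msg = "Not enough samples per class for cross-validation."
        raise ValueError(msg)
    return min(max_cv, min_class, len(y_train))

print(choose_num_cv([0, 0, 1, 1, 1]))  # 2: class 0 has only two samples
```

The returned value can then be passed as `cv=` to GridSearchCV, which uses StratifiedKFold by default for classifiers.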

Comment on lines +42 to +43
path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
) -> None:

🧹 Nitpick | 🔵 Trivial

Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).

♻️ Suggested signature tweak
-def test_setup_device_predictor_with_prediction(
-    path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
-) -> None:
+def test_setup_device_predictor_with_prediction(
+    path_uncompiled_circuits: Path, path_compiled_circuits: Path, *, gnn: bool, verbose: bool
+) -> None:
@@
-def test_get_prepared_training_data_false_input(gnn: bool) -> None:
+def test_get_prepared_training_data_false_input(*, gnn: bool) -> None:

Also applies to: 162-162

🧰 Tools
🪛 Ruff (0.14.13)

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)

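The keyword-only marker suggested above works as follows; the function here is a toy, not the test under review:

```python
def train(data: str, *, gnn: bool = False, verbose: bool = False) -> str:
    # Every parameter after the bare `*` must be passed by keyword, so
    # call sites read train("circuits", gnn=True) rather than the
    # ambiguous train("circuits", True) that Ruff's FBT001 flags.
    mode = "gnn" if gnn else "rf"
    return f"{mode}:{'v' if verbose else 'q'}:{data}"

print(train("circuits", gnn=True))  # "gnn:q:circuits"
try:
    train("circuits", True)  # positional boolean is now rejected
except TypeError:
    print("positional call rejected")
```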

@pytest.mark.parametrize(
("model_type", "verbose"), [("rf", False), ("gnn", False), ("gnn", True)], ids=["rf", "gnn", "gnn_verbose"]
)
def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:

🧹 Nitpick | 🔵 Trivial

Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).

♻️ Suggested signature tweak
-def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:
+def test_train_model_and_predict(device: Target, model_type: str, *, verbose: bool) -> None:
@@
-def test_train_and_qcompile_with_hellinger_model(
-    source_path: Path, target_path: Path, device: Target, model_type: str, verbose: bool
-) -> None:
+def test_train_and_qcompile_with_hellinger_model(
+    source_path: Path, target_path: Path, device: Target, model_type: str, *, verbose: bool
+) -> None:

Also applies to: 232-233

🧰 Tools
🪛 Ruff (0.14.13)

165-165: Boolean-typed positional argument in function definition

(FBT001)

