Skip to content

Conversation

@fusawa-yugo
Copy link
Contributor

@fusawa-yugo fusawa-yugo commented Dec 12, 2025

Contributor Agreements

Please read the contributor agreements and if you agree, please click the checkbox below.

  • I agree to the contributor agreements.

Tip

Please follow the Quick TODO list to smoothly merge your PR.

Motivation

I created a optuna wrapper for SafeCMA in cmaes.

Description of the changes

I implemented SafeCMASampler under package/sampler.

TODO List towards PR Merge

Please remove this section if this PR is not an addition of a new package.
Otherwise, please check the following TODO list:

  • Copy ./template/ to create your package
  • Replace <COPYRIGHT HOLDER> in LICENSE of your package with your name
  • Fill out README.md in your package
  • Add import statements of your function or class names to be used in __init__.py
  • (Optional) Add from __future__ import annotations at the head of any Python files that include typing to support older Python versions
  • Apply the formatter based on the tips in README.md
  • Check whether your module works as intended based on the tips in README.md

@fusawa-yugo fusawa-yugo marked this pull request as draft December 12, 2025 08:07
@fusawa-yugo fusawa-yugo marked this pull request as ready for review December 16, 2025 05:38
@y0z
Copy link
Member

y0z commented Dec 17, 2025

@sawa3030 Could you review this PR?

@y0z y0z added the new-package New packages label Dec 17, 2025
@fusawa-yugo
Copy link
Contributor Author

fusawa-yugo commented Dec 18, 2025

Here is a test code to see how it works.

code
import optuna
import optunahub
import numpy as np
from typing import Sequence
import matplotlib.pyplot as plt

dim = 2
seed = 42
np.random.seed(seed)

safe_seeds_num = 10
bounds = np.array([[-5, 5]] * dim)

def objective_function(x: Sequence[float]) -> float:
    return sum(x[i]**2 for i in range(dim))

def safe_function_1(x: Sequence[float]) -> float:
    return x[0]

def safe_function_2(x: Sequence[float]) -> float:
    return - sum(x[i] for i in range(dim))


def objective(trial: optuna.Trial) -> float:
    x = [trial.suggest_float(f"x{i}", bounds[i, 0], bounds[i, 1]) for i in range(dim)]
    return objective_function(x)

safety_threshold_1 = [0.0]
safe_seeds_1 = []
while len(safe_seeds_1) < safe_seeds_num:
    x = np.random.rand(dim) * (bounds[:, 1] - bounds[:, 0]) + bounds[:, 0]
    if safe_function_1(x) <= safety_threshold_1[0]:
        safe_seeds_1.append(x)

seeds_evals_1 = [objective_function(x) for x in safe_seeds_1]
seeds_safe_evals_1 = [[safe_function_1(x)] for x in safe_seeds_1]

safety_threshold_2 = [-1.0]
safe_seeds_2 = []
while len(safe_seeds_2) < safe_seeds_num:
    x = np.random.rand(dim) * (bounds[:, 1] - bounds[:, 0]) + bounds[:, 0]
    if safe_function_2(x) <= safety_threshold_2[0]:
        safe_seeds_2.append(x)

seeds_evals_2 = [objective_function(x) for x in safe_seeds_2]
seeds_safe_evals_2 = [[safe_function_2(x)] for x in safe_seeds_2]

sampler_1 = optunahub.load_local_module(
    package="samplers/safe_cma",
    registry_root="../../optunahub-registry/package"
).SafeCMASampler(
    safe_seeds=safe_seeds_1,
    seeds_evals=seeds_evals_1,
    seeds_safe_evals=seeds_safe_evals_1,
    safety_threshold=safety_threshold_1,
    safe_function=safe_function_1,
    seed=seed,
)

sampler_2 = optunahub.load_local_module(
    package="samplers/safe_cma",
    registry_root="../../optunahub-registry/package"
).SafeCMASampler(
    safe_seeds=safe_seeds_2,
    seeds_evals=seeds_evals_2,
    seeds_safe_evals=seeds_safe_evals_2,
    safety_threshold=safety_threshold_2,
    safe_function=safe_function_2,
    seed=seed,
)


def plot_results(study):
    trials = study.get_trials()
    values = np.array([trial.value for trial in trials])
    x0 = np.array([trial.params["x0"] for trial in trials])
    x1 = np.array([trial.params["x1"] for trial in trials])
    trial_numbers = np.array([trial.number for trial in trials])
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Left plot: values transition
    ax1.plot(values)
    ax1.set_xlabel("Trial")
    ax1.set_ylabel("Objective Value")
    ax1.set_title(f"Optimization History - {study.study_name}")
    ax1.grid(True)
    
    # Right plot: x0, x1 scatter
    scatter = ax2.scatter(x0, x1, c=trial_numbers, cmap="viridis", alpha=0.6)
    ax2.set_xlabel("x0")
    ax2.set_ylabel("x1")
    ax2.set_title(f"Parameter Space - {study.study_name}")
    ax2.grid(True)
    plt.colorbar(scatter, ax=ax2, label="Trial Number")
    
    plt.tight_layout()
    return fig

study_1 = optuna.create_study(sampler=sampler_1, study_name="x[0] <= 0")
study_2 = optuna.create_study(sampler=sampler_2, study_name="sum(x) >= 1")

study_1.optimize(objective, n_trials=500)
study_2.optimize(objective, n_trials=500)

plot_results(study_1)
plot_results(study_2)

plt.show()

results

Points are likely to be sampled within the safe region.

  • study_1
aaaaafavs
  • study_2
cnoahoiah

Copy link
Collaborator

@sawa3030 sawa3030 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few comments. PTAL!

@fusawa-yugo
Copy link
Contributor Author

@sawa3030
Thank you for your comments! Your suggestions were very helpful.
I’ve made updates. Could you take another look?

Copy link
Collaborator

@sawa3030 sawa3030 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay in getting back to you. I have a quick suggestion.

Copy link
Collaborator

@sawa3030 sawa3030 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for all the updates! LGTM

@sawa3030 sawa3030 merged commit 6711116 into optuna:main Jan 16, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

new-package New packages

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants