Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
- name: Download Annoy Index file from Google Drive
run: |
pip install gdown
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O ${{ github.workspace }}/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1/rag/models_index.ann
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O ${{ github.workspace }}/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1/rag/models_index.ann

- name: Compile and run tests
uses: eProsima/eProsima-CI/multiplatform/colcon_build_test@v0
Expand Down Expand Up @@ -230,7 +230,7 @@ jobs:
- name: Download file from Google Drive
run: |
pip install gdown
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O ${{ github.workspace }}/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1/rag/models_index.ann
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O ${{ github.workspace }}/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1/rag/models_index.ann

- name: Compile and run tests
uses: eProsima/eProsima-CI/multiplatform/clang_build_test@v0
Expand Down
8 changes: 5 additions & 3 deletions sustainml_docs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ compile_documentation()
# Test
###############################################################################
# Compile tests if CMake options requires it
compile_test_documentation(
"${PROJECT_SOURCE_DIR}/test" # Test directory
)
if (EXISTS "${PROJECT_SOURCE_DIR}/test/CMakeLists.txt")
compile_test_documentation("${PROJECT_SOURCE_DIR}/test")
else()
message(STATUS "[sustainml_docs] No tests directory; skipping test compilation.")
endif()

###############################################################################
# Packaging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,42 @@ And inside ``configuration_callback()`` implement the response to the configurat
# See the License for the specific language governing permissions and
# limitations under the License.
"""SustainML ML Model Provider Node Implementation."""

from sustainml_py.nodes.MLModelNode import MLModelNode

# Manage signaling
import os
import signal
import threading
import time
import json

from rdftool.ModelONNXCodebase import model
from neo4j import GraphDatabase
from rdftool.rdfCode import load_graph, get_models_for_problem, get_models_for_problem_and_tag

from rag.rag_backend import answer_question


# Neo4j config/driver for local checks (used by _model_has_goal)
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "12345678"
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

def _model_has_goal(neo4j_driver, model_name: str, goal: str) -> bool:
cypher = """
MATCH (m:Model {name: $model})-[:HAS_PROBLEM]->(p:Problem)
WHERE toLower(p.name) = toLower($goal)
RETURN COUNT(*) AS cnt
"""
with neo4j_driver.session() as s:
r = s.run(cypher, model=model_name, goal=goal).single()
return bool(r and r["cnt"] > 0)

# Whether to go on spinning or interrupt
running = False


# Load the list of unsupported
def load_unsupported_models(file_path):
try:
Expand All @@ -105,19 +122,19 @@ And inside ``configuration_callback()`` implement the response to the configurat
except Exception as e:
print(f"[WARN] Could not load unsupported list: {e}")
return []


unsupported_models = load_unsupported_models(os.path.dirname(__file__) + "/unsupported_models.txt")


# Signal handler
def signal_handler(sig, frame):
print("\nExiting")
MLModelNode.terminate()
global running
running = False


# User Callback implementation
# Inputs: ml_model_metadata, app_requirements, hw_constraints, ml_model_baseline, hw_baseline, carbonfootprint_baseline
# Outputs: node_status, ml_model
Expand All @@ -129,11 +146,11 @@ And inside ``configuration_callback()`` implement the response to the configurat
carbonfootprint_baseline,
node_status,
ml_model):

# Callback implementation here

print(f"Received Task: {ml_model_metadata.task_id().problem_id()},{ml_model_metadata.task_id().iteration_id()}")

try:
chosen_model = None
# Model restriction after various outputs
Expand All @@ -147,31 +164,70 @@ And inside ``configuration_callback()`` implement the response to the configurat
except json.JSONDecodeError:
print("[WARN] In model_provider node extra_data JSON is not valid.")
extra_data_dict = {}

if "type" in extra_data_dict:
type = extra_data_dict["type"]

if "model_restrains" in extra_data_dict:
restrained_models = extra_data_dict["model_restrains"]

if "model_selected" in extra_data_dict:
chosen_model = extra_data_dict["model_selected"]
print("Model already selected: ", chosen_model)

problem_short_description = extra_data_dict["problem_short_description"]

metadata = ml_model_metadata.ml_model_metadata()[0]

if chosen_model is None:
print(f"Problem short description: {problem_short_description}")

# Choose model with the RAG based on the goal selected and the knowledge of the graph.
chosen_model = answer_question(
f"Task {metadata} with problem description: {problem_short_description}?"
)


goal = ml_model_metadata.ml_model_metadata()[0] # goal selected by metadata node
print(f"Problem short description: {problem_short_description}")
print(f"Selected goal (metadata): {goal}")

# Build strictly goal-scoped allowed list (names only)
goal_models = get_models_for_problem(goal) # [(model_name, downloads), ...]
allowed_names = [name for (name, _) in goal_models]
print(f"[INFO] {len(allowed_names)} candidates for goal '{goal}'")
if not allowed_names:
raise Exception("No candidates in graph for the selected goal")

# Track models to avoid repeats across outputs
restrained_models = []
if extra_data_bytes:
try:
if "model_restrains" in extra_data_dict:
restrained_models = list(set(extra_data_dict["model_restrains"]))
except Exception:
pass

# Try up to 10 candidates, skipping misfits transparently
chosen_model = None
for _ in range(10):
remaining = [n for n in allowed_names if n not in restrained_models]
if not remaining:
break

candidate = answer_question(
f"Task {goal} with problem description: {problem_short_description}?",
allowed_models=remaining
)

if not candidate or candidate.strip().lower() == "none":
# mark and try again
if candidate:
restrained_models.append(candidate)
continue

# Final safety: ensure candidate really belongs to goal
if not _model_has_goal(neo4j_driver, candidate, goal):
print(f"[GUARD] Dropping {candidate}: not linked to goal {goal}")
restrained_models.append(candidate)
continue

chosen_model = candidate
break

if not chosen_model:
raise Exception("No suitable model after screening candidates")
print(f"ML Model chosen: {chosen_model}")

# Generate model code and keywords
onnx_path = model(chosen_model) # TODO - Further development needed
ml_model.model(chosen_model)
Expand All @@ -180,27 +236,27 @@ And inside ``configuration_callback()`` implement the response to the configurat
extra_data = {"unsupported_models": unsupported_models}
encoded_data = json.dumps(extra_data).encode("utf-8")
ml_model.extra_data(encoded_data)

except Exception as e:
print(f"Failed to determine ML model for task {ml_model_metadata.task_id()}: {e}.")
ml_model.model("Error")
ml_model.model_path("Error")
error_message = "Failed to obtain ML model for task: " + str(e)
error_info = {"error": error_message}
print(f"[WARN] No suitable model found for task {ml_model_metadata.task_id()}: {e}")
ml_model.model("NO_MODEL")
ml_model.model_path("N/A")
error_message = "No suitable model found for the given problem."
error_info = {"error_code": "NO_MODEL", "error": error_message}
encoded_error = json.dumps(error_info).encode("utf-8")
ml_model.extra_data(encoded_error)


# User Configuration Callback implementation
# Inputs: req
# Outputs: res
def configuration_callback(req, res):

# Callback for configuration implementation here
if 'model_from_goal' in req.configuration():
res.node_id(req.node_id())
res.transaction_id(req.transaction_id())

try:
text = req.configuration()[len("model_from_goal, "):]
parts = text.split(',')
Expand All @@ -211,24 +267,24 @@ And inside ``configuration_callback()`` implement the response to the configurat
else:
goal = text.strip()
models = get_models_for_problem(goal)

sorted_models = ', '.join(sorted([str(m[0]) for m in models]))

if not sorted_models:
res.success(False)
res.err_code(1) # 0: No error || 1: Error
else:
res.success(True)
res.err_code(0) # 0: No error || 1: Error

print(f"Models for {goal}: {sorted_models}") # debug
res.configuration(json.dumps(dict(models=sorted_models)))

except Exception as e:
print(f"Error getting model from goal from request: {e}")
res.success(False)
res.err_code(1)

else:
res.node_id(req.node_id())
res.transaction_id(req.transaction_id())
Expand All @@ -237,8 +293,8 @@ And inside ``configuration_callback()`` implement the response to the configurat
res.success(False)
res.err_code(1) # 0: No error || 1: Error
print(error_msg)


# Main workflow routine
def run():
start_time = time.time()
Expand All @@ -255,19 +311,19 @@ And inside ``configuration_callback()`` implement the response to the configurat
global running
running = True
node.spin()


# Call main in program execution
if __name__ == '__main__':
signal.signal(signal.SIGINT, signal_handler)

"""Python does not process signals async if
the main thread is blocked (spin()) so, tun
user work flow in another thread """
runner = threading.Thread(target=run)
runner.start()

while running:
time.sleep(1)

runner.join()
8 changes: 4 additions & 4 deletions sustainml_docs/rst/installation/framework.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ Install them using the package manager of the appropriate Linux distribution.

brew install \
curl wget git llvm cmake gcc python swig [email protected] asio tinyxml2 libp11 softhsm qt@5 openjdk@21 neo4j

# Add Java 21 to the PATH (bashrc, zshrc or fish config depending on your shell)
echo 'export JAVA_HOME=$(/usr/libexec/java_home -v 21)' >> ~/.zshrc
echo 'export PATH=$JAVA_HOME/bin:$PATH' >> ~/.zshrc
source ~/.zshrc

# Start Neo4j as a service
brew services start neo4j

Expand All @@ -99,7 +99,7 @@ The following command also builds and installs the SustainML framework and all i
vcs import src < sustainml.repos && cd ~/SustainML/SustainML_ws/src/sustainml_lib && \
git submodule update --init --recursive && \
cd ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1 && \
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O rag/models_index.ann && \
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O rag/models_index.ann && \
pip3 install -r ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/requirements.txt && \
cd ~/SustainML/SustainML_ws && colcon build && \
source ~/SustainML/SustainML_ws/install/setup.bash && \
Expand All @@ -123,7 +123,7 @@ The following command also builds and installs the SustainML framework and all i
vcs import src < sustainml.repos && cd ~/SustainML/SustainML_ws/src/sustainml_lib && \
git submodule update --init --recursive && \
cd ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1 && \
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O rag/models_index.ann && \
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O rag/models_index.ann && \
pip3 install -r ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/requirements.txt && \
cd ~/SustainML/SustainML_ws && colcon build --packages-up-to sustainml --cmake-args -DCMAKE_CXX_STANDARD=17 \
-DQt5_DIR=/usr/local/opt/qt5/lib/cmake/Qt5 && \
Expand Down
8 changes: 4 additions & 4 deletions sustainml_docs/rst/installation/library.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ Install them using the package manager of the appropriate OS distribution.

brew install \
wget git llvm cmake gcc python swig [email protected] asio tinyxml2 libp11 softhsm openjdk@21 neo4j

# Add Java 21 to the PATH (bashrc, zshrc or fish config depending on your shell)
echo 'export JAVA_HOME=$(/usr/libexec/java_home -v 21)' >> ~/.zshrc
echo 'export PATH=$JAVA_HOME/bin:$PATH' >> ~/.zshrc
source ~/.zshrc


.. _installation_library_build:

Expand All @@ -90,7 +90,7 @@ The following command builds and installs the *SustainML library* and its depend
vcs import src < sustainml.repos && \
git submodule update --init --recursive && \
cd ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1 && \
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O rag/models_index.ann && \
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O rag/models_index.ann && \
pip3 install -r ~/SustainML/SustainML_ws/src/sustainml_docs/requirements.txt && \
colcon build && \
sudo neo4j-admin database load system \
Expand All @@ -112,7 +112,7 @@ The following command builds and installs the *SustainML library* and its depend
vcs import src < sustainml.repos && \
git submodule update --init --recursive && \
cd ~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1 && \
gdown --id 1gTXqQtP9JS92gAtPzhKgZpdIHSV_Lcnp -O rag/models_index.ann && \
gdown --id 1TQvt1bSXares-I9l7Wki0Jge3oubRkOJ -O rag/models_index.ann && \
pip3 install -r ~/SustainML/SustainML_ws/src/sustainml_docs/requirements.txt && \
colcon build --cmake-args -DCMAKE_CXX_STANDARD=17 && \
sudo neo4j-admin database load system \
Expand Down
4 changes: 4 additions & 0 deletions sustainml_modules/test/communication/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,8 @@ if (Python3_FOUND)
-ap app_requirements_node.py
-py-orc RequestOrchestratorNode.py)

set_tests_properties(SimpleCommunicationOrchestratorSixNodesPython PROPERTIES TIMEOUT 600)
set_tests_properties(SimpleCommunicationPythonOrchestratorSixNodesPython PROPERTIES TIMEOUT 600)
set_tests_properties(SimpleServicePythonOrchestratorSixNodesPython PROPERTIES TIMEOUT 600)

endif()
Loading