Merged
233 changes: 104 additions & 129 deletions src/uckn/core/atoms/multi_modal_embeddings.py
@@ -7,12 +7,61 @@
 
 import hashlib
 import logging
+import os
+import threading
 from typing import Any
 
 import numpy as np
 
-from ..ml_environment_manager import get_ml_manager
+# Defensive import logic for torch and sentence-transformers
+SENTENCE_TRANSFORMERS_AVAILABLE = False
+TRANSFORMERS_AVAILABLE = False
+SentenceTransformer = None
+AutoTokenizer = None
+AutoModel = None
+torch = None
+
+_DISABLE_TORCH = os.environ.get("UCKN_DISABLE_TORCH", "0") == "1"
+
+if not _DISABLE_TORCH:
+    # Try importing torch and transformers defensively
+    try:
+        try:
+            import torch
+        except Exception:
+            torch = None  # type: ignore[assignment]
+            # Log or print for debugging, but do not raise
+        else:
+            try:
+                from transformers import AutoModel, AutoTokenizer
+
+                TRANSFORMERS_AVAILABLE = True
+            except Exception:
+                AutoTokenizer = None  # type: ignore[assignment]
+                AutoModel = None  # type: ignore[assignment]
+                TRANSFORMERS_AVAILABLE = False
+    except Exception:
+        torch = None
+        AutoTokenizer = None
+        AutoModel = None
+        TRANSFORMERS_AVAILABLE = False
+
+    # Try importing sentence-transformers defensively
+    try:
+        from sentence_transformers import SentenceTransformer
+
+        SENTENCE_TRANSFORMERS_AVAILABLE = True
+    except Exception:
+        SentenceTransformer = None
+        SENTENCE_TRANSFORMERS_AVAILABLE = False
+else:
+    # Torch is disabled by environment variable
+    torch = None
+    AutoTokenizer = None
+    AutoModel = None
+    SentenceTransformer = None
+    TRANSFORMERS_AVAILABLE = False
+    SENTENCE_TRANSFORMERS_AVAILABLE = False
 
 
 class MultiModalEmbeddings:
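
Note for reviewers: the availability flags are computed once, at import time, so UCKN_DISABLE_TORCH must be set before the module is first imported. A minimal sketch of exercising the disabled path (assuming the file's import path, uckn.core.atoms.multi_modal_embeddings):

    import os
    os.environ["UCKN_DISABLE_TORCH"] = "1"  # must happen before the first import

    from uckn.core.atoms import multi_modal_embeddings as mme

    # With the kill switch on, every optional dependency resolves to None
    assert mme.torch is None and mme.SentenceTransformer is None
    assert not mme.TRANSFORMERS_AVAILABLE
    assert not mme.SENTENCE_TRANSFORMERS_AVAILABLE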
@@ -31,25 +80,28 @@ class MultiModalEmbeddings:
 
     def __init__(self, device: str | None = None):
         self._logger = logging.getLogger(__name__)
-        self._ml_manager = get_ml_manager()
-
-        # Use ML manager to determine device
-        self.device = device or self._ml_manager.get_device()
+        # Defensive: If torch is unavailable, always use cpu
+        self.device: str = "cpu"  # Default value
+        if (
+            torch is not None
+            and hasattr(torch, "cuda")
+            and callable(getattr(torch.cuda, "is_available", None))
+        ):
+            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self._lock = threading.Lock()
 
         # Model loading
         self.code_tokenizer = None
         self.code_model = None
         self.text_model = None
 
-        # Initialize models based on environment capabilities
-        if self._ml_manager.should_use_real_ml():
+        # Only initialize models if not disabled
+        if not _DISABLE_TORCH:
             self._init_code_model()
             self._init_text_model()
         else:
-            env_info = self._ml_manager.get_environment_info()
-            self._logger.info(
-                f"Using fallback embeddings - Environment: {env_info['environment']}"
+            self._logger.warning(
+                "Torch and transformers are disabled by environment variable."
             )
 
         # In-memory cache for embeddings
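
The device guard above defends against stubbed or partially-installed torch builds, not just a missing package, which is why it checks that torch.cuda.is_available is actually callable. The same logic in isolation, as a sketch:

    def pick_device(requested: str | None = None) -> str:
        # CPU unless torch is importable and exposes a callable cuda.is_available
        if (
            torch is not None
            and hasattr(torch, "cuda")
            and callable(getattr(torch.cuda, "is_available", None))
        ):
            return requested or ("cuda" if torch.cuda.is_available() else "cpu")
        return "cpu"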
@@ -62,130 +114,56 @@ def is_available(self) -> bool:
 
         Returns:
             bool: True if at least one embedding model is available, False otherwise.
         """
-        # Component is always available - either real ML or fallbacks
-        caps = self._ml_manager.capabilities
-
-        has_real_models = (
-            caps.sentence_transformers and self.text_model is not None
-        ) or (
-            caps.transformers
+        # Component is available if at least one model is initialized
+        # or if we have the basic dependencies available
+        has_text_model = SENTENCE_TRANSFORMERS_AVAILABLE and self.text_model is not None
+        has_code_model = (
+            TRANSFORMERS_AVAILABLE
             and self.code_model is not None
             and self.code_tokenizer is not None
         )
-
-        # Always available: either real models or fallback embeddings
-        return has_real_models or caps.fallback_embeddings
-
-    def _generate_fake_embedding(self, text: str, dim: int = 384) -> list[float]:
-        """Generate deterministic fake embedding for testing when ML models unavailable."""
-        import hashlib
-        import re
-
-        # Extract words for semantic features
-        words = set(re.findall(r"\w+", text.lower()))
-
-        # Create word-based features for first part of embedding
-        word_features = []
-        common_words = {
-            "add",
-            "sum",
-            "two",
-            "numbers",
-            "values",
-            "def",
-            "function",
-            "class",
-            "setting",
-            "config",
-            "error",
-            "exception",
-            "true",
-            "false",
-            "return",
-            "division",
-            "zero",
-            "traceback",
-            "zerodivisionerror",
-            "by",
-        }
-
-        for common_word in sorted(common_words):
-            if common_word in words:
-                word_features.append(1.0)
-            else:
-                word_features.append(0.0)
-
-        # Pad or truncate to half the dimension
-        half_dim = dim // 2
-        while len(word_features) < half_dim:
-            word_features.append(0.0)
-        word_features = word_features[:half_dim]
-
-        # Create hash-based features for second half
-        hash_obj = hashlib.md5(text.encode(), usedforsecurity=False)
-        hash_bytes = hash_obj.digest()
-        hash_features = []
-
-        for i in range(dim - half_dim):
-            byte_val = hash_bytes[i % len(hash_bytes)]
-            # Smaller range for hash features to reduce noise
-            norm_val = (byte_val / 255.0) * 0.2 - 0.1
-            hash_features.append(norm_val)
-
-        # Combine features
-        embedding = word_features + hash_features
-
-        # Normalize to unit vector
-        norm = sum(x**2 for x in embedding) ** 0.5
-        if norm > 0:
-            embedding = [x / norm for x in embedding]
-
-        return embedding
+        # Available if we have at least one working model or basic dependencies
+        return has_text_model or has_code_model or SENTENCE_TRANSFORMERS_AVAILABLE
 
     def _init_code_model(self):
-        if not self._ml_manager.capabilities.transformers:
-            self._logger.debug(
+        if (
+            not TRANSFORMERS_AVAILABLE
+            or AutoTokenizer is None
+            or AutoModel is None
+            or torch is None
+        ):
+            self._logger.warning(
                 "Transformers not available. Code embedding will fallback to text model."
             )
             return
 
         try:
-            self.code_model, self.code_tokenizer = (
-                self._ml_manager.get_transformers_model(self._CODE_MODEL_NAME)
+            self.code_tokenizer = AutoTokenizer.from_pretrained(self._CODE_MODEL_NAME)
+            self.code_model = AutoModel.from_pretrained(self._CODE_MODEL_NAME).to(
+                self.device
             )
-            if self.code_model and self.code_tokenizer:
-                self._logger.info(f"Loaded code model: {self._CODE_MODEL_NAME}")
-            else:
-                self._logger.warning(
-                    f"Failed to load code model '{self._CODE_MODEL_NAME}'. Falling back to text model."
-                )
+            self._logger.info(f"Loaded code model: {self._CODE_MODEL_NAME}")
         except Exception as e:
             self._logger.warning(
-                f"Error loading code model '{self._CODE_MODEL_NAME}': {e}. Falling back to text model."
+                f"Failed to load code model '{self._CODE_MODEL_NAME}': {e}. Falling back to text model."
             )
             self.code_tokenizer = None
             self.code_model = None
 
     def _init_text_model(self):
-        if not self._ml_manager.capabilities.sentence_transformers:
-            self._logger.debug(
-                "SentenceTransformers not available. Text embedding will use fallbacks."
+        if not SENTENCE_TRANSFORMERS_AVAILABLE or SentenceTransformer is None:
+            self._logger.warning(
+                "SentenceTransformers not available. Text embedding will be disabled."
             )
             return
 
         try:
-            self.text_model = self._ml_manager.get_sentence_transformer(
-                self._TEXT_MODEL_NAME
+            self.text_model = SentenceTransformer(
+                self._TEXT_MODEL_NAME, device=self.device
             )
-            if self.text_model:
-                self._logger.info(f"Loaded text model: {self._TEXT_MODEL_NAME}")
-            else:
-                self._logger.warning(
-                    f"Failed to load text model '{self._TEXT_MODEL_NAME}'. Using fallbacks."
-                )
+            self._logger.info(f"Loaded text model: {self._TEXT_MODEL_NAME}")
         except Exception as e:
-            self._logger.warning(
-                f"Error loading text model '{self._TEXT_MODEL_NAME}': {e}. Using fallbacks."
+            self._logger.error(
+                f"Failed to load text model '{self._TEXT_MODEL_NAME}': {e}"
             )
             self.text_model = None
 
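Behavioral note: is_available() can now return False, where the old capability-based version always returned True because fallback embeddings counted as available. A hedged caller sketch, assuming embed() accepts a data_type keyword as the final hunk below suggests:

    emb = MultiModalEmbeddings()
    if emb.is_available():
        vector = emb.embed("def add(a, b):\n    return a + b", data_type="code")
    else:
        vector = None  # no torch / transformers / sentence-transformers here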
@@ -207,17 +185,12 @@ def _embed_code(self, code: str) -> list[float] | None:
         cached = self._get_cached_embedding(key)
         if cached:
             return cached
-        if (
-            self.code_model
-            and self.code_tokenizer
-            and self._ml_manager.capabilities.torch
-        ):
+        if self.code_model and self.code_tokenizer and torch is not None:
             try:
                 inputs = self.code_tokenizer(
                     code, return_tensors="pt", truncation=True, max_length=256
                 )
                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
-                torch = self._ml_manager._get_import("torch")
                 with torch.no_grad():
                     outputs = self.code_model(**inputs)
                     # Use [CLS] token representation
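
The pooling code after this point is collapsed in the diff view; for a BERT-style encoder, [CLS]-token pooling typically looks like the following sketch (not necessarily the exact lines in this PR):

    # outputs.last_hidden_state has shape (batch=1, seq_len, hidden_size)
    cls_vector = outputs.last_hidden_state[:, 0, :]  # the [CLS] position
    embedding = cls_vector.squeeze(0).cpu().numpy().tolist()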
@@ -248,9 +221,7 @@ def _embed_text(self, text: str) -> list[float] | None:
                 return embedding
             except Exception as e:
                 self._logger.error(f"Text embedding failed: {e}")
-
-        # Fallback: Generate deterministic fake embedding for testing
-        return self._generate_fake_embedding(text)
+        return None
 
     def _embed_config(self, config: str) -> list[float] | None:
         # Simple tokenization: split on newlines, colons, equals, etc.
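
With the deterministic fake-embedding fallback removed, _embed_text now returns None on failure, so call sites have to degrade explicitly rather than receiving a synthetic vector. A sketch of what callers now need:

    vec = self._embed_text(text)
    if vec is None:
        return None  # propagate the miss instead of fabricating an embedding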
@@ -284,17 +255,21 @@ def embed(
             data_type = data["type"]
             data = data["content"]
 
+        # Ensure data is a string at this point
+        if not isinstance(data, str):
+            self._logger.warning(
+                f"Expected string data, got {type(data)}. Converting to string."
+            )
+            data = str(data)
+
         if data_type == "auto":
             # Heuristic: detect type
-            if isinstance(data, str):
-                if data.strip().startswith("def ") or data.strip().startswith("class "):
-                    data_type = "code"
-                elif "=" in data and "\n" in data:
-                    data_type = "config"
-                elif "Traceback" in data or "Exception" in data:
-                    data_type = "error"
-                else:
-                    data_type = "text"
+            if data.strip().startswith("def ") or data.strip().startswith("class "):
+                data_type = "code"
+            elif "=" in data and "\n" in data:
+                data_type = "config"
+            elif "Traceback" in data or "Exception" in data:
+                data_type = "error"
+            else:
+                data_type = "text"
 
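The detection heuristic is ordered, so the first matching rule wins. Worked examples of how inputs route under the rules above:

    # "def add(a, b): return a + b"         -> "code"   (starts with "def ")
    # "timeout = 30\nretries = 2"           -> "config" ("=" plus a newline)
    # "Traceback (most recent call last):"  -> "error"  (no "=", so config is skipped)
    # "plain prose"                         -> "text"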