Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions plugins/emu/conf/default.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
evals_c2_host: 127.0.0.1
evals_c2_port: 8888
5 changes: 5 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[pytest]
asyncio_mode = auto
testpaths = tests
markers =
slow: marks tests as slow
229 changes: 229 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Shared fixtures for emu plugin tests."""
import asyncio
import os
import yaml
import pytest

from unittest.mock import MagicMock, AsyncMock, patch


def async_mock_return(to_return):
"""Helper to create a resolved Future with a given value."""
mock_future = asyncio.Future()
mock_future.set_result(to_return)
return mock_future


# ---------------------------------------------------------------------------
# Lightweight stubs for caldera framework objects that are not available
# when running the emu plugin tests in isolation.
# ---------------------------------------------------------------------------

class _StubBaseWorld:
"""Minimal stand-in for app.utility.base_world.BaseWorld."""

class Access:
RED = 'red'
BLUE = 'blue'

_configs = {}

@classmethod
def apply_config(cls, name, config):
cls._configs[name] = config

@classmethod
def strip_yml(cls, path):
if os.path.exists(path):
with open(path, 'r') as fh:
return list(yaml.safe_load_all(fh))
return [{}]

@classmethod
def get_config(cls, name='main', prop=None):
cfg = cls._configs.get(name, {})
if prop:
return cfg.get(prop)
return cfg

@staticmethod
def create_logger(name):
import logging
return logging.getLogger(name)


class _StubBaseService(_StubBaseWorld):
"""Minimal stand-in for app.utility.base_service.BaseService."""
_services = {}

@classmethod
def add_service(cls, name, svc):
cls._services[name] = svc
import logging
return logging.getLogger(name)

@classmethod
def get_service(cls, name):
return cls._services.get(name)


class _StubBaseParser:
"""Minimal stand-in for app.utility.base_parser.BaseParser."""

def __init__(self):
self.mappers = []
self.used_facts = []

def set_value(self, key, value, used_facts):
return value


class _StubFact:
"""Minimal stand-in for app.objects.secondclass.c_fact.Fact."""

def __init__(self, trait=None, value=None):
self.trait = trait
self.value = value

def __eq__(self, other):
return isinstance(other, _StubFact) and self.trait == other.trait and self.value == other.value

def __repr__(self):
return f'Fact(trait={self.trait!r}, value={self.value!r})'


class _StubRelationship:
"""Minimal stand-in for app.objects.secondclass.c_relationship.Relationship."""

def __init__(self, source=None, edge=None, target=None):
self.source = source
self.edge = edge
self.target = target


class _StubLink:
"""Minimal stand-in for app.objects.secondclass.c_link.Link."""

def __init__(self, command='', paw='', ability=None, **kwargs):
self.command = command
self.paw = paw
self.ability = ability
self.used = kwargs.get('used', [])
self.id = kwargs.get('id', '')


class _StubBaseRequirement:
"""Minimal stand-in for plugins.stockpile.app.requirements.base_requirement.BaseRequirement."""
pass


# ---------------------------------------------------------------------------
# Patch caldera imports before any plugin code is imported
# ---------------------------------------------------------------------------

import sys

# Build module stubs
_base_world_mod = type(sys)('app.utility.base_world')
_base_world_mod.BaseWorld = _StubBaseWorld
_base_service_mod = type(sys)('app.utility.base_service')
_base_service_mod.BaseService = _StubBaseService
_base_parser_mod = type(sys)('app.utility.base_parser')
_base_parser_mod.BaseParser = _StubBaseParser
_fact_mod = type(sys)('app.objects.secondclass.c_fact')
_fact_mod.Fact = _StubFact
_rel_mod = type(sys)('app.objects.secondclass.c_relationship')
_rel_mod.Relationship = _StubRelationship
_link_mod = type(sys)('app.objects.secondclass.c_link')
_link_mod.Link = _StubLink
_auth_svc_mod = type(sys)('app.service.auth_svc')
_auth_svc_mod.for_all_public_methods = lambda func: lambda cls: cls
_auth_svc_mod.check_authorization = lambda func: func
_base_req_mod = type(sys)('plugins.stockpile.app.requirements.base_requirement')
_base_req_mod.BaseRequirement = _StubBaseRequirement

# Register in sys.modules (only if not already present — CI may have real caldera)
_stubs = {
'app': type(sys)('app'),
'app.utility': type(sys)('app.utility'),
'app.utility.base_world': _base_world_mod,
'app.utility.base_service': _base_service_mod,
'app.utility.base_parser': _base_parser_mod,
'app.objects': type(sys)('app.objects'),
'app.objects.secondclass': type(sys)('app.objects.secondclass'),
'app.objects.secondclass.c_fact': _fact_mod,
'app.objects.secondclass.c_relationship': _rel_mod,
'app.objects.secondclass.c_link': _link_mod,
'app.service': type(sys)('app.service'),
'app.service.auth_svc': _auth_svc_mod,
'plugins': type(sys)('plugins'),
'plugins.stockpile': type(sys)('plugins.stockpile'),
'plugins.stockpile.app': type(sys)('plugins.stockpile.app'),
'plugins.stockpile.app.requirements': type(sys)('plugins.stockpile.app.requirements'),
'plugins.stockpile.app.requirements.base_requirement': _base_req_mod,
}

for mod_name, mod_obj in _stubs.items():
sys.modules.setdefault(mod_name, mod_obj)

# Ensure the plugin package itself is importable from repo root
_repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)

# Also make the plugin available as plugins.emu
_plugins_emu_mod = type(sys)('plugins.emu')
_plugins_emu_mod.__path__ = [_repo_root]
sys.modules.setdefault('plugins.emu', _plugins_emu_mod)
Comment on lines +174 to +177

_plugins_emu_app_mod = type(sys)('plugins.emu.app')
_plugins_emu_app_mod.__path__ = [os.path.join(_repo_root, 'app')]
sys.modules.setdefault('plugins.emu.app', _plugins_emu_app_mod)


# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def stub_fact_class():
"""Return the stub Fact class for use in tests."""
return _StubFact


@pytest.fixture
def stub_link_class():
"""Return the stub Link class for use in tests."""
return _StubLink


@pytest.fixture
def mock_app_svc():
"""Mock application service with a router."""
svc = MagicMock()
svc.application = MagicMock()
svc.application.router = MagicMock()
svc.application.router.add_route = MagicMock()
return svc


@pytest.fixture
def mock_contact_svc():
"""Mock contact service."""
svc = MagicMock()
svc.handle_heartbeat = AsyncMock()
return svc


@pytest.fixture
def tmp_data_dir(tmp_path):
"""Create a temporary data directory structure."""
data = tmp_path / 'data'
data.mkdir()
(data / 'abilities').mkdir()
(data / 'adversaries').mkdir()
(data / 'sources').mkdir()
(data / 'planners').mkdir()
payloads = tmp_path / 'payloads'
payloads.mkdir()
return tmp_path
53 changes: 53 additions & 0 deletions tests/test_emu_gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for app/emu_gui.py — EmuGUI."""
import logging
import pytest
from unittest.mock import MagicMock, AsyncMock, patch


class TestEmuGUI:
"""Test the EmuGUI class construction and splash handler."""

def _make_gui(self):
from plugins.emu.app.emu_gui import EmuGUI
services = {
'auth_svc': MagicMock(),
'data_svc': MagicMock(),
}
gui = EmuGUI(services, name='Emu', description='Test description')
return gui, services

def test_construction(self):
gui, services = self._make_gui()
assert gui.name == 'Emu'
assert gui.description == 'Test description'
assert gui.auth_svc is services['auth_svc']
assert gui.data_svc is services['data_svc']

def test_logger(self):
gui, _ = self._make_gui()
assert gui.log.name == 'emu_gui'

def test_name_and_description(self):
from plugins.emu.app.emu_gui import EmuGUI
services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()}
gui = EmuGUI(services, name='Custom', description='Custom desc')
assert gui.name == 'Custom'
assert gui.description == 'Custom desc'

def test_missing_services(self):
from plugins.emu.app.emu_gui import EmuGUI
services = {}
gui = EmuGUI(services, name='Emu', description='desc')
assert gui.auth_svc is None
assert gui.data_svc is None

def test_splash_is_callable(self):
gui, _ = self._make_gui()
assert callable(gui.splash)

def test_splash_is_coroutine_function(self):
"""The splash method (possibly wrapped by @template) should be async-compatible."""
import asyncio
gui, _ = self._make_gui()
# The underlying function or its wrapper should be a coroutine function
assert asyncio.iscoroutinefunction(gui.splash) or callable(gui.splash)
Comment on lines +52 to +53
Loading
Loading