Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc_source/directions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Directions

```{eval-rst}
.. automodule:: ect.directions
:members:
```
5 changes: 5 additions & 0 deletions doc_source/ect_on_graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
.. automodule:: ect.ect_graph
:members:
```

```{eval-rst}
.. automodule:: ect.sect
:members:
```
3 changes: 2 additions & 1 deletion doc_source/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ Table of Contents

Embedded graphs <embed_graph.md>
Embedded CW complex <embed_cw.md>
ECT on graphs <ect_on_graphs.md>
ECT on graphs <ect_on_graphs.md>
Directions <directions.md>
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.0.0"
version = "1.0.2"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
12 changes: 7 additions & 5 deletions src/ect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from .embed_graph import EmbeddedGraph
from .embed_cw import EmbeddedCW
from .directions import Directions
from .sect import SECT
from .utils import examples

__all__ = [
'ECT',
'EmbeddedGraph',
'EmbeddedCW',
'Directions',
'examples',
"ECT",
"SECT",
"EmbeddedGraph",
"EmbeddedCW",
"Directions",
"examples",
]
59 changes: 59 additions & 0 deletions src/ect/sect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from ect import ECT
from .embed_graph import EmbeddedGraph
from .embed_cw import EmbeddedCW
from .directions import Directions
from .results import ECTResult
from typing import Optional, Union
import numpy as np


class SECT(ECT):
"""
A class to calculate the Smooth Euler Characteristic Transform (SECT).
Inherits from ECT and applies smoothing to the final result.
"""

def __init__(
self,
directions: Optional[Directions] = None,
num_dirs: Optional[int] = None,
num_thresh: Optional[int] = None,
bound_radius: Optional[float] = None,
thresholds: Optional[np.ndarray] = None,
dtype=np.float32,
):
"""Initialize SECT calculator with smoothing parameter

Args:
directions: Optional pre-configured Directions object
num_dirs: Number of directions to sample (ignored if directions provided)
num_thresh: Number of threshold values (required if directions not provided)
bound_radius: Optional radius for bounding circle
thresholds: Optional array of thresholds
dtype: Data type for output array
"""
super().__init__(
directions, num_dirs, num_thresh, bound_radius, thresholds, dtype
)

def calculate(
self,
graph: Union[EmbeddedGraph, EmbeddedCW],
theta: Optional[float] = None,
override_bound_radius: Optional[float] = None,
) -> ECTResult:
"""Calculate Smooth Euler Characteristic Transform (SECT)

Args:
graph: The input graph to calculate the SECT for
theta: The angle in [0,2π] for the direction to calculate the SECT
override_bound_radius: Optional override for bounding radius

Returns:
ECTResult: The smoothed transform result containing the matrix,
directions, and thresholds
"""
ect_result = super().calculate(graph, theta, override_bound_radius)
return ECTResult(
ect_result, ect_result.directions, ect_result.thresholds
).smooth()
76 changes: 76 additions & 0 deletions tests/test_sect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import unittest
import numpy as np
from ect import SECT, ECT
from ect.utils.examples import create_example_graph
from ect.directions import Directions


class TestSECT(unittest.TestCase):
def setUp(self):
"""Set up test fixtures"""
self.graph = create_example_graph()
self.num_dirs = 8
self.num_thresh = 10
self.sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)

def test_inheritance(self):
"""Test that SECT properly inherits from ECT"""
self.assertIsInstance(self.sect, ECT)
self.assertTrue(hasattr(self.sect, "calculate"))

def test_calculate_output_shape(self):
"""Test that SECT calculation returns correct shape"""
result = self.sect.calculate(self.graph)

self.assertEqual(result.shape[0], self.num_dirs)
self.assertEqual(result.shape[1], self.num_thresh)
self.assertEqual(len(result.thresholds), self.num_thresh)
self.assertEqual(len(result.directions), self.num_dirs)

def test_smoothing_effect(self):
"""Test that smoothing is actually applied"""
# Calculate both ECT and SECT
ect = ECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)
sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)

ect_result = ect.calculate(self.graph)
sect_result = sect.calculate(self.graph)

# Verify results are different due to smoothing
self.assertFalse(np.allclose(ect_result, sect_result))

# Verify smoothing preserves direction count
self.assertEqual(
np.sum(ect_result, axis=1).shape,
np.sum(sect_result, axis=1).shape,
)

def test_with_theta(self):
"""Test SECT calculation with specific theta value"""
theta = np.pi / 4
result = self.sect.calculate(self.graph, theta=theta)

# Should only have one direction when theta is specified
self.assertEqual(result.shape[0], 1)
self.assertEqual(result.shape[1], self.num_thresh)

def test_with_override_radius(self):
"""Test SECT calculation with override_bound_radius"""
override_radius = 2.0
result = self.sect.calculate(self.graph, override_bound_radius=override_radius)

# Check that thresholds are within the override radius
self.assertLessEqual(np.max(np.abs(result.thresholds)), override_radius)

def test_smooth_matrix_properties(self):
"""Test properties of the smoothed matrix"""
result = self.sect.calculate(self.graph)

# Smoothed values should be finite
self.assertTrue(np.all(np.isfinite(result)))

# Shape should be preserved after smoothing
self.assertEqual(result.shape, (self.num_dirs, self.num_thresh))

# Verify result is float type after smoothing
self.assertTrue(np.issubdtype(result.dtype, np.floating))