Skip to content

Commit bff0e14

Browse files
committed
Add projection components for dimensionality reduction in communication systems
- Introduced a new module `projection.py` implementing various projection methods: Rademacher, Gaussian, Orthogonal, Complex Gaussian, and Complex Orthogonal. - Added `Projection` class for creating projection layers, supporting both real and complex projections. - Updated `__init__.py` to include `Projection` and `ProjectionType` in the module exports. - Created an example script `plot_projections_and_cover_tests.py` demonstrating the usage of projections and cover tests for evaluating projection quality in communication systems. - Visualizations include comparison of different projection types, projection matrices, distribution analysis, column orthogonality, and distance preservation tests.
1 parent c158257 commit bff0e14

File tree

4 files changed

+216
-1
lines changed

4 files changed

+216
-1
lines changed

docs/api_reference.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ Components module for Kaira models.
267267
ConvEncoder
268268
MLPDecoder
269269
MLPEncoder
270+
Projection
271+
ProjectionType
270272

271273

272274
Decoders

docs/references.bib

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,3 +823,20 @@ @book{blahut2003algebraic
823823
isbn={978-0-521-55374-5},
824824
doi={10.1017/CBO9780511800467}
825825
}
826+
827+
@article{yilmaz2025private,
828+
author={Yilmaz, Selim F. and Hasircioğlu, Burak and Qiao, Li and Gündüz, Deniz},
829+
journal={IEEE Transactions on Machine Learning in Communications and Networking},
830+
title={Private Collaborative Edge Inference via Over-the-Air Computation},
831+
year={2025},
832+
volume={3},
833+
number={},
834+
pages={215-231},
835+
doi={10.1109/TMLCN.2025.3526551}}
836+
837+
@article{yilmaz2025learning,
838+
title={Learning to Interfere in Non-Orthogonal Multiple-Access Joint Source-Channel Coding},
839+
author={Yilmaz, Selim F and Karamanli, Can and Gunduz, Deniz},
840+
journal={arXiv preprint arXiv:2504.03690},
841+
year={2025}
842+
}

kaira/models/components/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from .afmodule import AFModule
44
from .conv import ConvDecoder, ConvEncoder
55
from .mlp import MLPDecoder, MLPEncoder
6+
from .projection import Projection, ProjectionType
67

7-
__all__ = ["AFModule", "MLPEncoder", "MLPDecoder", "ConvEncoder", "ConvDecoder"]
8+
__all__ = ["AFModule", "MLPEncoder", "MLPDecoder", "ConvEncoder", "ConvDecoder", "Projection", "ProjectionType"]
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Projection components for dimensionality reduction in communication systems.
2+
3+
This module implements projection methods that can be used for dimensionality reduction in
4+
communication systems, as described and implemented in
5+
:cite:`yilmaz2025learning,yilmaz2025private`.
6+
"""
7+
8+
from enum import Enum
9+
from typing import Any, Optional, Union
10+
11+
import torch
12+
import torch.nn as nn
13+
14+
from kaira.models.base import BaseModel
15+
16+
from ..registry import ModelRegistry
17+
18+
19+
class ProjectionType(Enum):
20+
"""Enum for different types of projections.
21+
22+
Attributes:
23+
RADEMACHER: Random binary projection matrix with values {-1, 1}.
24+
Suitable for fast computation and memory-efficient implementations.
25+
GAUSSIAN: Random projection matrix with values drawn from N(0, 1/out_features).
26+
Provides good theoretical guarantees for dimensionality reduction.
27+
ORTHOGONAL: Random orthogonal matrix with columns that form an orthonormal basis.
28+
Preserves angles and distances better than non-orthogonal projections.
29+
COMPLEX_GAUSSIAN: Complex-valued projection with real and imaginary parts from N(0, 1/(2*out_features)).
30+
Useful for complex signal processing and wireless communications.
31+
COMPLEX_ORTHOGONAL: Complex-valued orthogonal projection with orthonormal columns.
32+
Provides optimal preservation of signal characteristics for complex data.
33+
"""
34+
35+
RADEMACHER = "rademacher"
36+
GAUSSIAN = "gaussian"
37+
ORTHOGONAL = "orthogonal"
38+
COMPLEX_GAUSSIAN = "complex_gaussian"
39+
COMPLEX_ORTHOGONAL = "complex_orthogonal"
40+
41+
42+
@ModelRegistry.register_model()
43+
class Projection(BaseModel):
44+
"""Projection layer for dimensionality reduction in communication systems
45+
:cite:`yilmaz2025private,yilmaz2025learning`.
46+
47+
This module implements different projection methods that can be used for dimensionality
48+
reduction in communication systems. These projection methods have been adapted from those
49+
used in :cite:`yilmaz2025private` and :cite:`yilmaz2025learning`. The projection only operates
50+
on the last dimension of the input tensor and uses matrix multiplication.
51+
52+
Available projection types:
53+
* RADEMACHER: Random matrix with values {-1, 1} (binary)
54+
* GAUSSIAN: Random matrix with values from N(0, 1/out_features)
55+
* ORTHOGONAL: Matrix with orthogonal columns (real-valued)
56+
* COMPLEX_GAUSSIAN: Complex matrix with real and imaginary parts from N(0, 1/(2*out_features))
57+
* COMPLEX_ORTHOGONAL: Complex matrix with orthogonal columns
58+
59+
Complex projections are particularly useful for wireless communication systems
60+
where signals are often represented in the complex domain with I/Q components.
61+
"""
62+
63+
def __init__(
64+
self,
65+
in_features: int,
66+
out_features: int,
67+
projection_type: Union[ProjectionType, str] = ProjectionType.ORTHOGONAL,
68+
seed: Optional[int] = None,
69+
trainable: bool = True,
70+
dtype: Optional[torch.dtype] = None,
71+
*args: Any,
72+
**kwargs: Any,
73+
):
74+
"""Initialize the Projection layer.
75+
76+
Args:
77+
in_features (int): The dimensionality of the input features.
78+
out_features (int): The dimensionality of the output features.
79+
projection_type (ProjectionType or str, optional): Type of projection to use.
80+
Possible values as enum: ProjectionType.RADEMACHER, ProjectionType.GAUSSIAN,
81+
ProjectionType.ORTHOGONAL, ProjectionType.COMPLEX_GAUSSIAN, ProjectionType.COMPLEX_ORTHOGONAL.
82+
Possible values as str: "rademacher", "gaussian", "orthogonal", "complex_gaussian", "complex_orthogonal".
83+
Default is ProjectionType.ORTHOGONAL.
84+
seed (int, optional): Random seed for reproducibility. Default is None.
85+
trainable (bool, optional): Whether the projection matrix is trainable.
86+
Default is True.
87+
dtype (torch.dtype, optional): The dtype of the projection matrix.
88+
Default is None, which will use float32 for real projections and complex64
89+
for complex projections.
90+
*args: Variable positional arguments passed to the base class.
91+
**kwargs: Variable keyword arguments passed to the base class.
92+
"""
93+
super().__init__(*args, **kwargs)
94+
95+
self.in_features = in_features
96+
self.out_features = out_features
97+
self.projection_type = projection_type if isinstance(projection_type, ProjectionType) else ProjectionType(projection_type)
98+
self.seed = seed
99+
self.trainable = trainable
100+
101+
# Determine if we're using a complex projection
102+
self.is_complex = self.projection_type in [ProjectionType.COMPLEX_GAUSSIAN, ProjectionType.COMPLEX_ORTHOGONAL]
103+
104+
# Determine dtype based on input or defaults
105+
if dtype is None:
106+
self.dtype = torch.complex64 if self.is_complex else torch.float32
107+
else:
108+
self.dtype = dtype
109+
110+
# Create local RNG for PyTorch
111+
torch_rng = torch.Generator()
112+
113+
# Set seed for the local RNG if provided
114+
if seed is not None:
115+
torch_rng.manual_seed(seed)
116+
117+
# Initialize projection matrix based on the specified type
118+
if self.projection_type == ProjectionType.RADEMACHER:
119+
# Rademacher distribution: Random matrix with values {-1, 1}
120+
projection = (torch.randint(0, 2, (in_features, out_features), generator=torch_rng) * 2 - 1).to(self.dtype)
121+
elif self.projection_type == ProjectionType.GAUSSIAN:
122+
# Gaussian distribution: Random matrix with values from N(0, 1/out_features)
123+
projection = (torch.randn(in_features, out_features, generator=torch_rng) / torch.sqrt(torch.tensor(out_features, dtype=torch.float32))).to(self.dtype)
124+
elif self.projection_type == ProjectionType.ORTHOGONAL:
125+
# Orthogonal matrix: Using QR decomposition for orthogonal initialization
126+
random_matrix = torch.randn(max(in_features, out_features), min(in_features, out_features), generator=torch_rng)
127+
q, r = torch.linalg.qr(random_matrix)
128+
# Use the sign of diagonal elements of r to ensure deterministic results
129+
d = torch.diagonal(r)
130+
ph = d.sign()
131+
q *= ph
132+
133+
if in_features >= out_features:
134+
projection = q[:in_features, :out_features].to(self.dtype)
135+
else:
136+
projection = q[:in_features, :out_features].t().to(self.dtype)
137+
elif self.projection_type == ProjectionType.COMPLEX_GAUSSIAN:
138+
# Complex Gaussian: Real and imaginary parts from N(0, 1/(2*out_features))
139+
# Factor of 1/2 ensures same expected power as real Gaussian
140+
real_part = torch.randn(in_features, out_features, generator=torch_rng) / torch.sqrt(torch.tensor(2 * out_features, dtype=torch.float32))
141+
imag_part = torch.randn(in_features, out_features, generator=torch_rng) / torch.sqrt(torch.tensor(2 * out_features, dtype=torch.float32))
142+
projection = torch.complex(real_part, imag_part).to(self.dtype)
143+
elif self.projection_type == ProjectionType.COMPLEX_ORTHOGONAL:
144+
# Complex Orthogonal matrix: Generate a random complex matrix and orthogonalize
145+
real_part = torch.randn(max(in_features, out_features), min(in_features, out_features), generator=torch_rng)
146+
imag_part = torch.randn(max(in_features, out_features), min(in_features, out_features), generator=torch_rng)
147+
random_matrix = torch.complex(real_part, imag_part)
148+
149+
# Use QR decomposition to get an orthogonal basis
150+
q, r = torch.linalg.qr(random_matrix)
151+
152+
# Normalize phases to ensure deterministic results
153+
d = torch.diagonal(r)
154+
ph = d / torch.abs(d) # Unit complex numbers preserving phase
155+
q *= ph
156+
157+
if in_features >= out_features:
158+
projection = q[:in_features, :out_features].to(self.dtype)
159+
else:
160+
projection = q[:in_features, :out_features].t().to(self.dtype)
161+
else:
162+
raise ValueError(f"Unknown projection type: {projection_type}")
163+
164+
# Register the projection matrix as a parameter or buffer
165+
if trainable:
166+
self.projection = nn.Parameter(projection)
167+
else:
168+
self.register_buffer("projection", projection)
169+
170+
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
171+
"""Forward pass of the Projection layer.
172+
173+
Args:
174+
x (torch.Tensor): Input tensor with the last dimension being the features.
175+
For complex projections, x can be either a complex tensor or a real tensor.
176+
If x is real and the projection is complex, x will be treated as having
177+
only real components.
178+
*args: Additional positional arguments (unused).
179+
**kwargs: Additional keyword arguments (unused).
180+
181+
Returns:
182+
torch.Tensor: Output tensor with the last dimension projected.
183+
If the projection is complex, the output will be complex.
184+
"""
185+
# Handle type conversions if needed
186+
if self.is_complex and not torch.is_complex(x):
187+
# Real input with complex projection - treat input as having only real part
188+
x = torch.complex(x, torch.zeros_like(x))
189+
190+
# Perform matrix multiplication on the last dimension
191+
return x @ self.projection
192+
193+
def extra_repr(self) -> str:
194+
"""Return extra representation string for the module."""
195+
return f"in_features={self.in_features}, out_features={self.out_features}, " f"projection_type={self.projection_type.value}, is_complex={self.is_complex}, " f"dtype={self.dtype}, trainable={self.trainable}"

0 commit comments

Comments
 (0)