|
| 1 | +"""Projection components for dimensionality reduction in communication systems. |
| 2 | +
|
| 3 | +This module implements projection methods that can be used for dimensionality reduction in |
| 4 | +communication systems, as described and implemented in |
| 5 | +:cite:`yilmaz2025learning,yilmaz2025private`. |
| 6 | +""" |
| 7 | + |
| 8 | +from enum import Enum |
| 9 | +from typing import Any, Optional, Union |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn as nn |
| 13 | + |
| 14 | +from kaira.models.base import BaseModel |
| 15 | + |
| 16 | +from ..registry import ModelRegistry |
| 17 | + |
| 18 | + |
class ProjectionType(Enum):
    """Kinds of random projection matrix that :class:`Projection` can build.

    Each member's value is the lowercase string accepted in place of the enum
    member itself.

    Attributes:
        RADEMACHER: Entries drawn uniformly from {-1, +1}; cheap to generate
            and to store.
        GAUSSIAN: Real entries drawn from a zero-mean normal distribution with
            variance ``1 / out_features``.
        ORTHOGONAL: Real matrix obtained from the QR decomposition of a random
            Gaussian matrix, so the basis vectors are orthonormal.
        COMPLEX_GAUSSIAN: Complex entries whose real and imaginary parts are
            each zero-mean normal with variance ``1 / (2 * out_features)``.
        COMPLEX_ORTHOGONAL: Complex matrix obtained from the QR decomposition
            of a random complex Gaussian matrix (orthonormal basis vectors).
    """

    # Real-valued projections.
    RADEMACHER = "rademacher"
    GAUSSIAN = "gaussian"
    ORTHOGONAL = "orthogonal"

    # Complex-valued projections.
    COMPLEX_GAUSSIAN = "complex_gaussian"
    COMPLEX_ORTHOGONAL = "complex_orthogonal"
| 40 | + |
| 41 | + |
@ModelRegistry.register_model()
class Projection(BaseModel):
    """Projection layer for dimensionality reduction in communication systems
    :cite:`yilmaz2025private,yilmaz2025learning`.

    This module implements different projection methods that can be used for dimensionality
    reduction in communication systems. These projection methods have been adapted from those
    used in :cite:`yilmaz2025private` and :cite:`yilmaz2025learning`. The projection only operates
    on the last dimension of the input tensor and uses matrix multiplication with a matrix of
    shape ``(in_features, out_features)``.

    Available projection types:
        * RADEMACHER: Random matrix with values {-1, 1} (binary)
        * GAUSSIAN: Random matrix with values from N(0, 1/out_features)
        * ORTHOGONAL: Real matrix with orthonormal columns (or orthonormal rows when
          ``in_features < out_features``)
        * COMPLEX_GAUSSIAN: Complex matrix with real and imaginary parts from N(0, 1/(2*out_features))
        * COMPLEX_ORTHOGONAL: Complex matrix with orthonormal columns (or orthonormal rows when
          ``in_features < out_features``)

    Complex projections are particularly useful for wireless communication systems
    where signals are often represented in the complex domain with I/Q components.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        projection_type: Union[ProjectionType, str] = ProjectionType.ORTHOGONAL,
        seed: Optional[int] = None,
        trainable: bool = True,
        dtype: Optional[torch.dtype] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the Projection layer.

        Args:
            in_features (int): The dimensionality of the input features.
            out_features (int): The dimensionality of the output features.
            projection_type (ProjectionType or str, optional): Type of projection to use.
                Possible values as enum: ProjectionType.RADEMACHER, ProjectionType.GAUSSIAN,
                ProjectionType.ORTHOGONAL, ProjectionType.COMPLEX_GAUSSIAN, ProjectionType.COMPLEX_ORTHOGONAL.
                Possible values as str: "rademacher", "gaussian", "orthogonal", "complex_gaussian", "complex_orthogonal".
                Default is ProjectionType.ORTHOGONAL.
            seed (int, optional): Random seed for reproducibility. When given, a private
                generator seeded with this value is used, so the global RNG state is left
                untouched. When None (default), the matrix is drawn from PyTorch's global
                RNG and therefore follows ``torch.manual_seed``.
            trainable (bool, optional): Whether the projection matrix is trainable.
                Default is True.
            dtype (torch.dtype, optional): The dtype of the projection matrix.
                Default is None, which will use float32 for real projections and complex64
                for complex projections.
            *args: Variable positional arguments passed to the base class.
            **kwargs: Variable keyword arguments passed to the base class.

        Raises:
            ValueError: If ``projection_type`` is a string that does not name a
                :class:`ProjectionType` member.
        """
        super().__init__(*args, **kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.projection_type = projection_type if isinstance(projection_type, ProjectionType) else ProjectionType(projection_type)
        self.seed = seed
        self.trainable = trainable

        # Determine if we're using a complex projection
        self.is_complex = self.projection_type in (ProjectionType.COMPLEX_GAUSSIAN, ProjectionType.COMPLEX_ORTHOGONAL)

        # Determine dtype based on input or defaults
        if dtype is None:
            self.dtype = torch.complex64 if self.is_complex else torch.float32
        else:
            self.dtype = dtype

        # A freshly constructed torch.Generator always starts from the same built-in
        # default seed, so it must only be used when the caller asked for a seed.
        # Without a seed, fall back to the global RNG (generator=None); otherwise every
        # "unseeded" instance would silently receive the identical matrix.
        generator = torch.Generator().manual_seed(seed) if seed is not None else None

        projection = self._sample_projection(generator).to(self.dtype)

        # Register the projection matrix as a parameter or buffer
        if trainable:
            self.projection = nn.Parameter(projection)
        else:
            self.register_buffer("projection", projection)

    def _sample_projection(self, generator: Optional[torch.Generator]) -> torch.Tensor:
        """Draw the ``(in_features, out_features)`` matrix for ``self.projection_type``."""
        shape = (self.in_features, self.out_features)

        if self.projection_type == ProjectionType.RADEMACHER:
            # Rademacher distribution: map {0, 1} draws onto {-1, 1}
            return torch.randint(0, 2, shape, generator=generator) * 2 - 1

        if self.projection_type == ProjectionType.GAUSSIAN:
            # Gaussian distribution: values from N(0, 1/out_features)
            scale = torch.sqrt(torch.tensor(self.out_features, dtype=torch.float32))
            return torch.randn(shape, generator=generator) / scale

        if self.projection_type == ProjectionType.ORTHOGONAL:
            return self._random_orthonormal(self.in_features, self.out_features, generator, complex_valued=False)

        if self.projection_type == ProjectionType.COMPLEX_GAUSSIAN:
            # Real and imaginary parts from N(0, 1/(2*out_features)); the factor 1/2
            # gives the same expected power as the real Gaussian projection.
            scale = torch.sqrt(torch.tensor(2 * self.out_features, dtype=torch.float32))
            real_part = torch.randn(shape, generator=generator) / scale
            imag_part = torch.randn(shape, generator=generator) / scale
            return torch.complex(real_part, imag_part)

        if self.projection_type == ProjectionType.COMPLEX_ORTHOGONAL:
            return self._random_orthonormal(self.in_features, self.out_features, generator, complex_valued=True)

        raise ValueError(f"Unknown projection type: {self.projection_type}")

    @staticmethod
    def _random_orthonormal(
        in_features: int,
        out_features: int,
        generator: Optional[torch.Generator],
        complex_valued: bool,
    ) -> torch.Tensor:
        """Return a random ``(in_features, out_features)`` matrix with orthonormal columns or rows."""
        tall, short = max(in_features, out_features), min(in_features, out_features)

        random_matrix = torch.randn(tall, short, generator=generator)
        if complex_valued:
            imag_part = torch.randn(tall, short, generator=generator)
            random_matrix = torch.complex(random_matrix, imag_part)

        # Reduced QR of a (tall, short) matrix: q is (tall, short) with orthonormal columns.
        q, r = torch.linalg.qr(random_matrix)

        # QR is only unique up to a unit phase per column. Rescaling each column by the
        # phase of the matching diagonal entry of r fixes that ambiguity. torch.sgn
        # gives the sign for real input and d/|d| for complex input, and returns 0
        # (not NaN) for a zero entry.
        q = q * torch.sgn(torch.diagonal(r))

        if in_features >= out_features:
            return q

        # Here q is (out_features, in_features). The whole matrix must be transposed to
        # obtain (in_features, out_features) with orthonormal rows; slicing q before the
        # transpose would yield a square (in_features, in_features) matrix instead.
        return q.t().contiguous()

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass of the Projection layer.

        Args:
            x (torch.Tensor): Input tensor with the last dimension being the features.
                For complex projections, x can be either a complex tensor or a real tensor.
                If x is real and the projection matrix is complex, x is converted to the
                projection's dtype, i.e. treated as having only real components.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: Output tensor with the last dimension projected.
                If the projection is complex, the output will be complex.
        """
        # Decide on promotion from the stored matrix rather than from the projection
        # type, so an explicit ``dtype`` override cannot cause a real/complex mismatch,
        # and cast straight to the matrix dtype so the precisions also agree.
        if torch.is_complex(self.projection) and not torch.is_complex(x):
            x = x.to(self.projection.dtype)

        # Perform matrix multiplication on the last dimension
        return x @ self.projection

    def extra_repr(self) -> str:
        """Return extra representation string for the module."""
        return f"in_features={self.in_features}, out_features={self.out_features}, " f"projection_type={self.projection_type.value}, is_complex={self.is_complex}, " f"dtype={self.dtype}, trainable={self.trainable}"
0 commit comments