Skip to content

Commit 45513bc

Browse files
committed
Updated docstrings
1 parent dcd52a4 commit 45513bc

File tree

3 files changed

+26
-26
lines changed

3 files changed

+26
-26
lines changed

utils/dataloaders/svhn.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,28 @@ def __init__(
1616
nr_channels=3,
1717
):
1818
"""
19-
Initializes the SVHNDataset object.
19+
Initializes the SVHNDataset object for loading the Street View House Numbers (SVHN) dataset.
2020
Args:
21-
data_path (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
22-
transforms: Torch composite of transformations which are to be applied to the dataset images.
23-
download_data (bool): If True, downloads the dataset to the specified data_path.
24-
split (str): The dataset split to use, either 'train' or 'test'.
21+
data_path (str): Path to where the data is stored. If `download` is set to True, this is where the data will be downloaded.
22+
train (bool): If True, loads the training split of the dataset; otherwise, loads the test split.
23+
transform (callable, optional): A function/transform to apply to the images.
24+
download (bool): If True, downloads the dataset to the specified `data_path`.
25+
nr_channels (int): Number of channels in the images. Default is 3 for RGB images.
2526
Raises:
2627
AssertionError: If the split is not 'train' or 'test'.
2728
"""
2829
super().__init__()
29-
# assert split == "train" or split == "test"
3030
self.split = "train" if train else "test"
3131

3232
if download:
3333
self._download_data(data_path)
3434

3535
data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))
3636

37-
# Images on the form N x H x W x C
37+
# Reform images to the form N x H x W x C
3838
self.images = data["X"].transpose(3, 1, 0, 2)
3939
self.labels = data["y"].flatten()
40+
4041
self.labels[self.labels == 10] = 0
4142

4243
self.nr_channels = nr_channels
@@ -45,13 +46,11 @@ def __init__(
4546

4647
def _download_data(self, path: str):
4748
"""
48-
Downloads the SVHN dataset.
49+
Downloads the SVHN dataset to the specified directory.
4950
Args:
5051
path (str): The directory where the dataset will be downloaded.
51-
split (str): The dataset split to download, either 'train' or 'test'.
5252
"""
5353
print(f"Downloading SVHN data into {path}")
54-
5554
SVHN(path, split=self.split, download=True)
5655

5756
def __len__(self):
@@ -74,7 +73,6 @@ def __getitem__(self, index):
7473

7574
if self.nr_channels == 1:
7675
img = np.mean(img, axis=2, keepdims=True)
77-
7876
if self.transforms is not None:
7977
img = self.transforms(img)
8078

utils/metrics/EntropyPred.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,32 @@
55
class EntropyPrediction(nn.Module):
66
def __init__(self, averages: str = "average"):
77
"""
8-
Initializes the EntropyPrediction module.
8+
Initializes the EntropyPrediction module, which calculates the Shannon Entropy
9+
of predicted logits and aggregates the results based on the specified method.
910
Args:
1011
averages (str): Specifies the method of aggregation for entropy values.
11-
Must be either 'mean', 'sum' or 'none.
12+
Must be one of 'mean', 'sum', or 'none'.
1213
Raises:
13-
AssertionError: If the averages parameter is not 'mean' or 'sum'.
14+
AssertionError: If the averages parameter is not 'mean', 'sum', or 'none'.
1415
"""
1516
super().__init__()
1617

17-
assert averages == "mean" or averages == "sum"
18+
assert averages in ["mean", "sum", "none"], (
19+
"averages must be 'mean', 'sum', or 'none'"
20+
)
1821
self.averages = averages
1922
self.stored_entropy_values = []
2023

2124
def __call__(self, y_true, y_logits):
2225
"""
23-
Computes the Shannon Entropy of the predicted logits, storing the results.
26+
Computes the Shannon Entropy of the predicted logits and stores the results.
2427
Args:
25-
y_true: The true labels. Does nothing, but needed for compatability sake.
26-
y_logits: The predicted logits.
28+
y_true: The true labels. This parameter is not used in the computation
29+
but is included for compatibility with certain interfaces.
30+
y_logits: The predicted logits from which entropy is calculated.
31+
Returns:
32+
torch.Tensor: The aggregated entropy value(s) based on the specified
33+
method ('mean', 'sum', or 'none').
2734
"""
2835
entropy_values = entropy(y_logits, axis=1)
2936
entropy_values = th.from_numpy(entropy_values)
@@ -34,10 +41,8 @@ def __call__(self, y_true, y_logits):
3441

3542
if self.averages == "mean":
3643
entropy_values = th.mean(entropy_values)
37-
3844
elif self.averages == "sum":
3945
entropy_values = th.sum(entropy_values)
40-
4146
elif self.averages == "none":
4247
return entropy_values
4348

utils/models/magnus_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,16 @@ class MagnusModel(nn.Module):
55
def __init__(self, image_shape, num_classes: int, nr_channels: int):
66
"""
77
Initializes the MagnusModel, a neural network designed for image classification tasks.
8-
98
The model consists of three linear layers, each with 133 neurons, and uses ReLU activation
109
functions between the layers. The first layer's input size is determined by the image shape
1110
and number of channels, while the output layer's size is determined by the number of classes.
1211
Args:
13-
image_shape (tuple): A tuple representing the dimensions of the input image (Channels, Height, Width).
12+
image_shape (tuple): A tuple representing the dimensions of the input image (Height, Width).
1413
num_classes (int): The number of output classes for classification.
1514
nr_channels (int): The number of channels in the input image.
16-
Returns:
17-
MagnusModel (nn.Module): An instance of the MagnusModel neural network.
1815
"""
1916
super().__init__()
2017
*_, H, W = image_shape
21-
2218
self.layer1 = nn.Sequential(
2319
*(
2420
[
@@ -40,7 +36,8 @@ def forward(self, x):
4036
"""
4137
Defines the forward pass of the MagnusModel.
4238
Args:
43-
x (torch.Tensor): A four-dimensional tensor with shape (Batch Size, Channels, Image Height, Image Width).
39+
x (torch.Tensor): A four-dimensional tensor with shape
40+
(Batch Size, Channels, Image Height, Image Width).
4441
Returns:
4542
torch.Tensor: The output tensor containing class logits for each input sample.
4643
"""

0 commit comments

Comments
 (0)