Skip to content

Commit ec8b11d

Browse files
committed
core\refac: #68 init tfds for crowd seg
- init tfds style dataset for crowd seg data
1 parent 5373c43 commit ec8b11d

File tree

24 files changed

+676
-539
lines changed

24 files changed

+676
-539
lines changed

core/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
description = "Framework for handling image segmentation in the context of multiple annotators"
33
name = "seg_tgce"
4-
version = "0.2.1.dev3"
4+
version = "0.2.2.dev1"
55
readme = "README.md"
66
authors = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }]
77
maintainers = [{ name = "Brandon Lotero", email = "blotero@gmail.com" }]
@@ -15,7 +15,7 @@ Issues = "https://github.com/blotero/seg_tgce/issues"
1515

1616
[tool.poetry]
1717
name = "seg_tgce"
18-
version = "0.2.1.dev3"
18+
version = "0.2.2.dev1"
1919
authors = ["Brandon Lotero <blotero@gmail.com>"]
2020
description = "A package for the SEG TGCE project"
2121
readme = "README.md"
Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,3 @@
1-
from typing import Tuple
1+
from .generator import CrowdSegDataGenerator, Stage, get_crowd_seg_data
22

3-
from .generator import CrowdSegDataGenerator, DataSchema
4-
from .stage import Stage
5-
6-
DEFAULT_TARGET_SIZE = (512, 512)
7-
8-
9-
def get_all_data(
10-
image_size: Tuple[int, int] = DEFAULT_TARGET_SIZE,
11-
batch_size: int = 32,
12-
shuffle: bool = False,
13-
with_sparse_data: bool = False,
14-
trim_n_scorers: int | None = None,
15-
) -> Tuple[CrowdSegDataGenerator, ...]:
16-
"""
17-
Retrieve all data generators for the crowd segmentation task.
18-
returns a tuple of ImageDataGenerator instances for the train, val, and test stages.
19-
"""
20-
return tuple(
21-
CrowdSegDataGenerator(
22-
batch_size=batch_size,
23-
image_size=image_size,
24-
shuffle=shuffle,
25-
stage=stage,
26-
schema=DataSchema.MA_SPARSE if with_sparse_data else DataSchema.MA_RAW,
27-
trim_n_scorers=trim_n_scorers,
28-
)
29-
for stage in (Stage.TRAIN, Stage.VAL, Stage.TEST)
30-
)
31-
32-
33-
def get_stage_data(
34-
stage: Stage,
35-
image_size: Tuple[int, int] = DEFAULT_TARGET_SIZE,
36-
batch_size: int = 32,
37-
shuffle: bool = False,
38-
with_sparse_data: bool = False,
39-
) -> CrowdSegDataGenerator:
40-
"""
41-
Retrieve a data generator for a specific stage of the crowd segmentation task.
42-
"""
43-
return CrowdSegDataGenerator(
44-
batch_size=batch_size,
45-
image_size=image_size,
46-
shuffle=shuffle,
47-
stage=stage,
48-
schema=DataSchema.MA_SPARSE if with_sparse_data else DataSchema.MA_RAW,
49-
)
3+
__all__ = ["CrowdSegDataGenerator", "Stage", "get_crowd_seg_data"]
Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,23 @@
11
from matplotlib import pyplot as plt
22

3-
from seg_tgce.data.crowd_seg import get_all_data
4-
from seg_tgce.data.crowd_seg.generator import CrowdSegDataGenerator
3+
from seg_tgce.data.crowd_seg import get_crowd_seg_data
54

65

76
def main() -> None:
87
print("Loading data...")
9-
train, val, test = get_all_data(batch_size=16)
8+
train, val, test = get_crowd_seg_data(batch_size=128)
109

11-
# Get a sample batch from each generator
12-
train_batch = next(iter(train))
13-
val_batch = next(iter(val))
14-
test_batch = next(iter(test))
15-
16-
# Print shapes
17-
print("\nTrain data shapes:")
18-
print(f"Images shape: {train_batch[0].shape}")
19-
print(f"Ground truth mask shape: {train_batch[1].shape}")
20-
print(f"Labeler masks shape: {train_batch[2].shape}")
21-
22-
print("\nValidation data shapes:")
23-
print(f"Images shape: {val_batch[0].shape}")
24-
print(f"Ground truth mask shape: {val_batch[1].shape}")
25-
print(f"Labeler masks shape: {val_batch[2].shape}")
26-
27-
print("\nTest data shapes:")
28-
print(f"Images shape: {test_batch[0].shape}")
29-
print(f"Ground truth mask shape: {test_batch[1].shape}")
30-
print(f"Labeler masks shape: {test_batch[2].shape}")
31-
32-
fig = val.visualize_sample(batch_index=75, sample_indexes=[0, 1, 4, 5])
33-
fig.tight_layout()
34-
fig.savefig(
35-
"/home/brandon/unal/maestria/master_thesis/Cap1/Figures/multiannotator-segmentation.png"
10+
fig = train.visualize_sample(
11+
batch_index=6, sample_indexes=[0, 1, 30, 31, 63, 64, 126, 127]
3612
)
13+
fig.tight_layout()
14+
# fig.savefig(
15+
# "/home/brandon/unal/maestria/master_thesis/Cap1/Figures/multiannotator-segmentation.png"
16+
# )
17+
plt.show()
3718
print(f"Train: {len(train)} batches, {len(train) * train.batch_size} samples")
3819
print(f"Val: {len(val)} batches, {len(val) * val.batch_size} samples")
3920
print(f"Test: {len(test)} batches, {len(test) * test.batch_size} samples")
4021

41-
print("Loading train data with trimmed scorers...")
42-
train = CrowdSegDataGenerator(
43-
batch_size=8,
44-
trim_n_scorers=6,
45-
)
46-
print(f"Train: {len(train)} batches, {len(train) * train.batch_size} samples")
47-
print(f"Train scorers tags: {train.scorers_tags}")
48-
4922

5023
main()

core/seg_tgce/data/crowd_seg/__retrieve.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from botocore import UNSIGNED
77
from botocore.client import Config
88

9-
from .stage import Stage
9+
from seg_tgce.data.crowd_seg.types import Stage
1010

1111
LOGGER = logging.getLogger(__name__)
1212
logging.basicConfig(level=logging.INFO)
@@ -21,11 +21,11 @@
2121

2222

2323
def get_masks_dir(stage: Stage) -> str:
24-
return os.path.join(_TARGET_DIR, "masks", stage.value)
24+
return os.path.join(_TARGET_DIR, "masks", stage.capitalize())
2525

2626

2727
def get_patches_dir(stage: Stage) -> str:
28-
return os.path.join(_TARGET_DIR, "patches", stage.value)
28+
return os.path.join(_TARGET_DIR, "patches", stage.capitalize())
2929

3030

3131
def _unzip_dirs() -> None:
@@ -62,13 +62,9 @@ def verify_path(path: str, with_raise: bool = False) -> bool:
6262

6363

6464
def fetch_data() -> None:
65-
paths_to_verify = [
66-
get_patches_dir(Stage.TRAIN),
67-
get_patches_dir(Stage.VAL),
68-
get_patches_dir(Stage.TEST),
69-
get_masks_dir(Stage.TRAIN),
70-
get_masks_dir(Stage.VAL),
71-
get_masks_dir(Stage.TEST),
65+
stages: tuple[Stage, ...] = ("train", "val", "test")
66+
paths_to_verify = [get_patches_dir(stage) for stage in stages] + [
67+
get_masks_dir(stage) for stage in stages
7268
]
7369
if all(verify_path(path) for path in paths_to_verify):
7470
return

0 commit comments

Comments
 (0)