Skip to content

Commit 1824e78

Browse files
committed
Merge branch 'main' of github.com:codinglabsong/aging-gan
2 parents 20b77f4 + 05757fb commit 1824e78

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

src/aging_gan/data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212

1313
class UTKFace(Dataset):
1414
"""
15-
Assumes the unzipped UTKFace images live in <root>/data/UTKFace
15+
Assumes the unzipped aligned UTKFace images live in <root>/data/utkface_aligned_cropped/UTKFace
1616
File pattern: {age}_{gender}_{race}_{yyyymmddHHMMSS}.jpg
1717
"""
1818

1919
def __init__(self, root: str, transform: T.Compose | None = None):
2020
self.root = (
21-
Path(root) / "UTKFace" # "utkface_aligned_cropped" /
21+
Path(root) / "utkface_aligned_cropped" / "UTKFace"
2222
) # or "UTKFace" for the unaligned and varied original version.
2323
self.files = sorted(f for f in self.root.glob("*.jpg"))
2424
if not self.files:
@@ -132,7 +132,7 @@ def prepare_dataset(
132132
# randomness
133133
train_transform = T.Compose(
134134
[
135-
T.ToPILImage(),
135+
# T.ToPILImage(),
136136
T.RandomHorizontalFlip(),
137137
T.Resize((img_size + 50, img_size + 50), antialias=True),
138138
T.RandomCrop(img_size),

src/aging_gan/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def parse_args() -> argparse.Namespace:
9494
p.add_argument(
9595
"--steps_for_logging_metrics",
9696
type=int,
97-
default=1,
97+
default=50,
9898
help="Print training metrics after certain batch steps.",
9999
)
100100
p.add_argument(

src/aging_gan/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,18 @@ def generate_and_save_samples(
103103
device: torch.device,
104104
num_samples: int = 8,
105105
):
106-
# grab one batch
107-
inputs, _ = next(iter(val_loader))
108-
inputs = inputs.to(device)[:num_samples]
106+
# grab batches until num_samples
107+
collected = []
108+
for imgs, _ in val_loader:
109+
collected.append(imgs)
110+
if sum(b.size(0) for b in collected) >= num_samples:
111+
break
112+
113+
if not collected:
114+
raise ValueError("Validation loader is empty.")
115+
116+
inputs = torch.cat(collected, dim=0)[:num_samples].to(device)
117+
109118
with torch.no_grad():
110119
outputs = generator(inputs)
111120

0 commit comments

Comments
 (0)