Skip to content

Commit 05757fb

Browse files
author
Ubuntu
committed
small bugs
1 parent 9a24166 commit 05757fb

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

src/aging_gan/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def prepare_dataset(
132132
# randomness
133133
train_transform = T.Compose(
134134
[
135-
T.ToPILImage(),
135+
# T.ToPILImage(),
136136
T.RandomHorizontalFlip(),
137137
T.Resize((img_size + 50, img_size + 50), antialias=True),
138138
T.RandomCrop(img_size),

src/aging_gan/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def parse_args() -> argparse.Namespace:
9494
p.add_argument(
9595
"--steps_for_logging_metrics",
9696
type=int,
97-
default=1,
97+
default=50,
9898
help="Print training metrics after certain batch steps.",
9999
)
100100
p.add_argument(

src/aging_gan/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,18 @@ def generate_and_save_samples(
103103
device: torch.device,
104104
num_samples: int = 8,
105105
):
106-
# grab one batch
107-
inputs, _ = next(iter(val_loader))
108-
inputs = inputs.to(device)[:num_samples]
106+
# grab batches until num_samples
107+
collected = []
108+
for imgs, _ in val_loader:
109+
collected.append(imgs)
110+
if sum(b.size(0) for b in collected) >= num_samples:
111+
break
112+
113+
if not collected:
114+
raise ValueError("Validation loader is empty.")
115+
116+
inputs = torch.cat(collected, dim=0)[:num_samples].to(device)
117+
109118
with torch.no_grad():
110119
outputs = generator(inputs)
111120

0 commit comments

Comments
 (0)