File tree Expand file tree Collapse file tree 3 files changed +16
-7
lines changed Expand file tree Collapse file tree 3 files changed +16
-7
lines changed Original file line number Diff line number Diff line change 1212
1313class UTKFace (Dataset ):
1414 """
15- Assumes the unzipped UTKFace images live in <root>/data/UTKFace
15+ Assumes the unzipped aligned UTKFace images live in <root>/data/utkface_aligned_cropped /UTKFace
1616 File pattern: {age}_{gender}_{race}_{yyyymmddHHMMSS}.jpg
1717 """
1818
1919 def __init__ (self , root : str , transform : T .Compose | None = None ):
2020 self .root = (
21- Path (root ) / "UTKFace" # "utkface_aligned_cropped" /
21+ Path (root ) / "utkface_aligned_cropped" / "UTKFace"
2222 ) # or "UTKFace" for the unaligned and varied original version.
2323 self .files = sorted (f for f in self .root .glob ("*.jpg" ))
2424 if not self .files :
@@ -132,7 +132,7 @@ def prepare_dataset(
132132 # randomness
133133 train_transform = T .Compose (
134134 [
135- T .ToPILImage (),
135+ # T.ToPILImage(),
136136 T .RandomHorizontalFlip (),
137137 T .Resize ((img_size + 50 , img_size + 50 ), antialias = True ),
138138 T .RandomCrop (img_size ),
Original file line number Diff line number Diff line change @@ -94,7 +94,7 @@ def parse_args() -> argparse.Namespace:
9494 p .add_argument (
9595 "--steps_for_logging_metrics" ,
9696 type = int ,
97- default = 1 ,
97+ default = 50 ,
9898 help = "Print training metrics after certain batch steps." ,
9999 )
100100 p .add_argument (
Original file line number Diff line number Diff line change @@ -103,9 +103,18 @@ def generate_and_save_samples(
103103 device : torch .device ,
104104 num_samples : int = 8 ,
105105):
106- # grab one batch
107- inputs , _ = next (iter (val_loader ))
108- inputs = inputs .to (device )[:num_samples ]
106+ # grab batches until num_samples
107+ collected = []
108+ for imgs , _ in val_loader :
109+ collected .append (imgs )
110+ if sum (b .size (0 ) for b in collected ) >= num_samples :
111+ break
112+
113+ if not collected :
114+ raise ValueError ("Validation loader is empty." )
115+
116+ inputs = torch .cat (collected , dim = 0 )[:num_samples ].to (device )
117+
109118 with torch .no_grad ():
110119 outputs = generator (inputs )
111120
You can’t perform that action at this time.
0 commit comments