Commit 6fec5b9

Auto format everything
1 parent 6b12cdc commit 6fec5b9

84 files changed: +4803, -3622 lines

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 [tool.black]
-line-length = 120
-skip-string-normalization = true
+line-length = 100
+skip-string-normalization = false
 target-version = ['py311']
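
Note (not part of the commit): the two pyproject.toml changes mean Black now wraps at 100 columns and re-enables string normalization, which is why most hunks in this commit only touch quoting and line breaks. A minimal sketch of checking the effect programmatically, assuming the black package is importable and recent enough to know the py311 target:

import black

mode = black.Mode(
    line_length=100,
    string_normalization=True,  # corresponds to skip-string-normalization = false
    target_versions={black.TargetVersion.PY311},
)
src = "x = {'a': 1,  'b': 2}\n"
print(black.format_str(src, mode=mode))  # -> x = {"a": 1, "b": 2}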

scripts/add_pose_pseudolabels.py

Lines changed: 87 additions & 55 deletions
@@ -29,58 +29,66 @@

 def setup_loader(args):
     ds = Hdf5PoseDataset(
-        args.filename,
-        transform=Compose([
-            dtr.batch.offset_points_by_half_pixel, # For when pixels are considered grid cell centers
-        ]),
-        monochrome=True)
+        args.filename,
+        transform=Compose(
+            [
+                dtr.batch.offset_points_by_half_pixel,  # For when pixels are considered grid cell centers
+            ]
+        ),
+        monochrome=True,
+    )
     if args.dryrun:
         ds = Subset(ds, np.arange(10))
     N = len(ds)
-    loader = dtr.PostprocessingLoader(ds, args.batchsize,
+    loader = dtr.PostprocessingLoader(
+        ds,
+        args.batchsize,
         shuffle=False,
         num_workers=utils.num_workers(),
         postprocess=None,
-        collate_fn= lambda samples : samples,
+        collate_fn=lambda samples: samples,
     )
     return loader, ds


-def fit_batch(net : InferenceNetwork, batch : List[Batch]):
-    images = [ s['image'] for s in batch ]
-    rois = torch.stack([ s['roi'] for s in batch ])
-    indices = torch.stack([ s['index'] for s in batch ])
-    out = Predictor(net, focus_roi_expansion_factor=1.2).predict_batch(images, rois, )
-    out = {
-        k:out[k] for k in 'unnormalized_quat coord pt3d_68 shapeparam'.split()
-    }
-    out.update(index = indices)
+def fit_batch(net: InferenceNetwork, batch: List[Batch]):
+    images = [s["image"] for s in batch]
+    rois = torch.stack([s["roi"] for s in batch])
+    indices = torch.stack([s["index"] for s in batch])
+    out = Predictor(net, focus_roi_expansion_factor=1.2).predict_batch(
+        images,
+        rois,
+    )
+    out = {k: out[k] for k in "unnormalized_quat coord pt3d_68 shapeparam".split()}
+    out.update(index=indices)
     return out


 def test_quats_average():
     def positivereal(q):
-        s = np.sign(q[...,3])
-        return q*s[...,None]
+        s = np.sign(q[..., 3])
+        return q * s[..., None]
+
     from scipy.spatial.transform import Rotation
+
     expected_quats = Rotation.random(10).as_quat()
     quats = Rotation.from_quat(np.repeat(expected_quats, 10, axis=0))
-    offsets = Rotation.random(10*10).as_rotvec()*0.01
+    offsets = Rotation.random(10 * 10).as_rotvec() * 0.01
     quats = quats * Rotation.from_rotvec(offsets)
-    quats = quats.as_quat().reshape((10,10,4)).transpose(1,0,2)
+    quats = quats.as_quat().reshape((10, 10, 4)).transpose(1, 0, 2)
     out = quat_average(quats)
-    #print (positivereal(out) - positivereal(expected_quats))
-    assert np.allclose(positivereal(out) , positivereal(expected_quats), atol=0.02)
+    # print (positivereal(out) - positivereal(expected_quats))
+    assert np.allclose(positivereal(out), positivereal(expected_quats), atol=0.02)


 @torch.no_grad()
 def fitall(args):
     assert all(isfile(f) for f in args.checkpoints)
-    print ("Inferring from networks:", args.checkpoints)
+    print("Inferring from networks:", args.checkpoints)

-    with h5py.File(args.filename, 'r+') as f:
+    with h5py.File(args.filename, "r+") as f:
         g = f.require_group(args.hdfgroupname) if args.hdfgroupname else f
-        for key in 'coords quats pt3d_68 shapeparams':
+        for key in "coords quats pt3d_68 shapeparams":
             try:
                 del g[key]
             except KeyError:
@@ -91,22 +99,20 @@ def fitall(args):

     outputs_per_net = defaultdict(list)
     for modelfile in tqdm.tqdm(args.checkpoints, desc="Network"):
-        net = load_pose_network(modelfile, 'cuda')
-        outputs = [ fit_batch(net, batch) for batch in tqdm.tqdm(loader, "Batch") ]
+        net = load_pose_network(modelfile, "cuda")
+        outputs = [fit_batch(net, batch) for batch in tqdm.tqdm(loader, "Batch")]
         outputs = utils.list_of_dicts_to_dict_of_lists(outputs)
-        outputs = {k:np.concatenate(v,axis=0) for k,v in outputs.items() }
-        ordering = np.argsort(outputs.pop('index'))
-        outputs = { k:v[ordering] for k,v in outputs.items() }
-        for k,v in outputs.items():
+        outputs = {k: np.concatenate(v, axis=0) for k, v in outputs.items()}
+        ordering = np.argsort(outputs.pop("index"))
+        outputs = {k: v[ordering] for k, v in outputs.items()}
+        for k, v in outputs.items():
             outputs_per_net[k].append(v)
         del outputs
-    outputs_per_net = {
-        k:np.stack(v) for k,v in outputs_per_net.items()
-    }
+    outputs_per_net = {k: np.stack(v) for k, v in outputs_per_net.items()}

     del loader
     del ds
-    gc.collect() # Ensure the hdf5 file in the data was really closed.
+    gc.collect()  # Ensure the hdf5 file in the data was really closed.
     # There is no way to enforce it. We can only hope the garbage
     # collector will destroy the objects. If there is still a reference
     # left it will be read-only and lead to failure when trying to write
@@ -115,30 +121,56 @@ def fitall(args):
     # FIXME: final quats output is busted. Values are more or less garbage.
     # unnormalized_quat looks fine!

-    quats = quat_average(outputs_per_net.pop('unnormalized_quat'))
-    coords = np.average(outputs_per_net.pop('coord'), axis=0)
-    pt3d_68 = np.average(outputs_per_net.pop('pt3d_68'), axis=0)
-    shapeparams = np.average(outputs_per_net.pop('shapeparam'), axis=0)
+    quats = quat_average(outputs_per_net.pop("unnormalized_quat"))
+    coords = np.average(outputs_per_net.pop("coord"), axis=0)
+    pt3d_68 = np.average(outputs_per_net.pop("pt3d_68"), axis=0)
+    shapeparams = np.average(outputs_per_net.pop("shapeparam"), axis=0)

     assert len(quats) == num_samples

-    with h5py.File(args.filename, 'r+') as f:
+    with h5py.File(args.filename, "r+") as f:
         g = f.require_group(args.hdfgroupname) if args.hdfgroupname else f
-        ds_quats = create_pose_dataset(g, C.quat, count=num_samples, data=quats, exists_ok=args.overwrite)
-        ds_coords = create_pose_dataset(g, C.xys, count=num_samples, data=coords, exists_ok=args.overwrite)
-        ds_pt3d_68 = create_pose_dataset(g, C.points, name='pt3d_68', count=num_samples, shape_wo_batch_dim=(68,3), data=pt3d_68, exists_ok=args.overwrite)
-        ds_shapeparams = create_pose_dataset(g,C.general, name='shapeparams', count=num_samples, shape_wo_batch_dim=(50,), data = shapeparams, exists_ok=args.overwrite)
-
-
-if __name__== '__main__':
+        ds_quats = create_pose_dataset(
+            g, C.quat, count=num_samples, data=quats, exists_ok=args.overwrite
+        )
+        ds_coords = create_pose_dataset(
+            g, C.xys, count=num_samples, data=coords, exists_ok=args.overwrite
+        )
+        ds_pt3d_68 = create_pose_dataset(
+            g,
+            C.points,
+            name="pt3d_68",
+            count=num_samples,
+            shape_wo_batch_dim=(68, 3),
+            data=pt3d_68,
+            exists_ok=args.overwrite,
+        )
+        ds_shapeparams = create_pose_dataset(
+            g,
+            C.general,
+            name="shapeparams",
+            count=num_samples,
+            shape_wo_batch_dim=(50,),
+            data=shapeparams,
+            exists_ok=args.overwrite,
+        )
+
+
+if __name__ == "__main__":
     test_quats_average()
     parser = argparse.ArgumentParser()
-    parser.add_argument('filename', type=str, help='the dataset to label')
-    parser.add_argument('-c','--checkpoints', help='model checkpoint', nargs='*', type=str)
-    parser.add_argument('-b','--batchsize', help="The batch size", type=int, default=512)
-    parser.add_argument('--hdf-group-name', help="Group to store the annotations in", type=str, default='', dest='hdfgroupname')
-    parser.add_argument('--dryrun', default=False, action='store_true')
-    parser.add_argument('--overwrite', '-f', default=False, action='store_true')
+    parser.add_argument("filename", type=str, help="the dataset to label")
+    parser.add_argument("-c", "--checkpoints", help="model checkpoint", nargs="*", type=str)
+    parser.add_argument("-b", "--batchsize", help="The batch size", type=int, default=512)
+    parser.add_argument(
+        "--hdf-group-name",
+        help="Group to store the annotations in",
+        type=str,
+        default="",
+        dest="hdfgroupname",
+    )
+    parser.add_argument("--dryrun", default=False, action="store_true")
+    parser.add_argument("--overwrite", "-f", default=False, action="store_true")
     args = parser.parse_args()
-    args.device = 'cuda'
-    fitall(args)
+    args.device = "cuda"
+    fitall(args)
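
Note (not part of the commit): quat_average combines per-network rotation predictions above, but its implementation is outside this diff. A minimal sketch of one standard, sign-invariant way to average unit quaternions over the first (per-network) axis, via the dominant eigenvector of the accumulated outer products; quat_average_sketch is a hypothetical stand-in and may differ from the real function:

import numpy as np

def quat_average_sketch(quats: np.ndarray) -> np.ndarray:
    """Average quaternions of shape (n_networks, ..., 4) down to (..., 4)."""
    q = quats / np.linalg.norm(quats, axis=-1, keepdims=True)
    # Outer-product accumulation is invariant to the sign ambiguity q ~ -q.
    m = np.einsum("n...i,n...j->...ij", q, q)
    w, v = np.linalg.eigh(m)  # eigenvalues in ascending order
    return v[..., :, -1]  # eigenvector of the largest eigenvalue

# Example: stack of per-network predictions for N samples -> one quaternion per sample.
# avg = quat_average_sketch(np.stack(per_net_quats))  # per_net_quats: list of (N, 4) arrays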

scripts/create_aflw2k3d_closed_eyes.py

Lines changed: 18 additions & 11 deletions
@@ -8,30 +8,37 @@
 from face3drotationaugmentation.generate import augment_eyes_only, make_sample_for_passthrough
 from face3drotationaugmentation.datasetwriter import dataset_writer

-deg2rad = np.pi/180.
+deg2rad = np.pi / 180.0


-def main(filename : str, outputfilename : str, max_num_frames : int, prob_closed_eyes : float):
+def main(filename: str, outputfilename: str, max_num_frames: int, prob_closed_eyes: float):
     rng = np.random.RandomState(seed=1234567)

     with closing(DatasetAFLW2k3D(filename)) as ds300wlp, dataset_writer(outputfilename) as writer:
         num_frames = min(max_num_frames, len(ds300wlp))
         for _, sample in tqdm.tqdm(zip(range(num_frames), ds300wlp), total=num_frames):
-            if sample['scale'] <= 0.: # TODO: actual decent validation?
-                print (f"Error: invalid head size = {sample['scale']}. Putting original sample!")
+            if sample["scale"] <= 0.0:  # TODO: actual decent validation?
+                print(f"Error: invalid head size = {sample['scale']}. Putting original sample!")
                 generated_sample = make_sample_for_passthrough(sample)
             else:
                 generated_sample = augment_eyes_only(prob_closed_eyes, rng, sample)
-            writer.write(sample['name'], generated_sample)
+            writer.write(sample["name"], generated_sample)


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser("Only Eye Augmentation")
     parser.add_argument("aflw2k3d", type=str, help="zip file")
     parser.add_argument("outputfilename", type=str, help="hdf5 file")
-    parser.add_argument("-n", help="subset of n samples", type=int, default=1<<32)
-    parser.add_argument("--prob-closed-eyes", type=float, default=0., help="probability for closing eyes (between 0 and 1)")
+    parser.add_argument("-n", help="subset of n samples", type=int, default=1 << 32)
+    parser.add_argument(
+        "--prob-closed-eyes",
+        type=float,
+        default=0.0,
+        help="probability for closing eyes (between 0 and 1)",
+    )
     args = parser.parse_args()
-    if not (args.outputfilename.lower().endswith('.h5') or args.outputfilename.lower().endswith('.hdf5')):
-        raise ValueError("outputfilename must have hdf5 filename extension")
-    main(args.aflw2k3d, args.outputfilename, args.n, prob_closed_eyes=args.prob_closed_eyes)
+    if not (
+        args.outputfilename.lower().endswith(".h5") or args.outputfilename.lower().endswith(".hdf5")
+    ):
+        raise ValueError("outputfilename must have hdf5 filename extension")
+    main(args.aflw2k3d, args.outputfilename, args.n, prob_closed_eyes=args.prob_closed_eyes)

scripts/dsjoin.py

Lines changed: 6 additions & 2 deletions
@@ -15,7 +15,9 @@ def _batched_copy(dst, src, dst_offset=0):
     n = min(dst.shape[0] + dst_offset, src.shape[0])
     for a in range(0, n, bs):
         b = min(n, a + bs)
-        dst[a + dst_offset : b + dst_offset, ...] = src[a:b, ...] # Buffer in memory, then write. This works ...
+        dst[a + dst_offset : b + dst_offset, ...] = src[
+            a:b, ...
+        ]  # Buffer in memory, then write. This works ...


 def concatenating_join(name1: str, items: Sequence[h5py.Dataset], fout: h5py.File):
@@ -24,7 +26,9 @@ def concatenating_join(name1: str, items: Sequence[h5py.Dataset], fout: h5py.Fil
     N = sum(sizes)
     print(f"Copying {name1}: {sizes} items of type {first.dtype}")

-    dst = fout.create_dataset_like(name1, first, shape=(N, *first.shape[1:]), maxshape=(N,) + first.shape[1:])
+    dst = fout.create_dataset_like(
+        name1, first, shape=(N, *first.shape[1:]), maxshape=(N,) + first.shape[1:]
+    )
     assert all(list(first.attrs.items()) == list(ds.attrs.items()) for ds in items)
     copy_attributes(first, dst)
     try:
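
Note (not part of the commit): _batched_copy writes the source into the destination one slab of rows at a time, so a large HDF5 dataset never has to be loaded whole. A minimal, self-contained sketch of the same concatenation pattern with plain h5py; the file and dataset names and the batch size are made up for illustration:

import h5py

BS = 4096  # rows per write, illustrative only

with h5py.File("a.h5", "r") as fa, h5py.File("b.h5", "r") as fb, h5py.File("joined.h5", "w") as fo:
    srcs = [fa["images"], fb["images"]]  # hypothetical dataset name
    n_total = sum(s.shape[0] for s in srcs)
    dst = fo.create_dataset("images", shape=(n_total, *srcs[0].shape[1:]), dtype=srcs[0].dtype)
    offset = 0
    for src in srcs:
        for a in range(0, src.shape[0], BS):
            b = min(src.shape[0], a + BS)
            dst[offset + a : offset + b, ...] = src[a:b, ...]  # buffer one slab, then write
        offset += src.shape[0]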
