Skip to content
This repository was archived by the owner on Jul 30, 2024. It is now read-only.

Commit a14357d

Browse files
Author: Krishna Murthy (committed)
Commit message: Lint
Signed-off-by: Krishna Murthy <[email protected]>
1 parent 6189b96 commit a14357d

File tree

6 files changed

+39
-22
lines changed

6 files changed

+39
-22
lines changed

eval_nerf.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
import yaml
1010
from tqdm import tqdm
1111

12-
from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data,
13-
models, get_embedding_function, run_one_iter_of_nerf)
12+
from nerf import (
13+
CfgNode,
14+
get_ray_bundle,
15+
load_blender_data,
16+
load_llff_data,
17+
models,
18+
get_embedding_function,
19+
run_one_iter_of_nerf,
20+
)
1421

1522

1623
def cast_to_image(tensor, dataset_type):
@@ -85,7 +92,7 @@ def main():
8592
encode_position_fn = get_embedding_function(
8693
num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
8794
include_input=cfg.models.coarse.include_input_xyz,
88-
log_sampling=cfg.models.coarse.log_sampling_xyz
95+
log_sampling=cfg.models.coarse.log_sampling_xyz,
8996
)
9097

9198
encode_direction_fn = None
@@ -174,7 +181,9 @@ def main():
174181
times_per_image.append(time.time() - start)
175182
if configargs.savedir:
176183
savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
177-
imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
184+
imageio.imwrite(
185+
savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())
186+
)
178187
if configargs.save_disparity_image:
179188
savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png")
180189
imageio.imwrite(savefile, cast_to_disparity_image(disp))

nerf/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(
136136
use_viewdirs=True,
137137
):
138138
super(PaperNeRFModel, self).__init__()
139-
139+
140140
include_input_xyz = 3 if include_input_xyz else 0
141141
include_input_dir = 3 if include_input_dir else 0
142142
self.dim_xyz = include_input_xyz + 2 * 3 * num_encoding_fn_xyz
@@ -161,7 +161,7 @@ def __init__(
161161
self.relu = torch.nn.functional.relu
162162

163163
def forward(self, x):
164-
xyz, dirs = x[..., :self.dim_xyz], x[..., self.dim_xyz:]
164+
xyz, dirs = x[..., : self.dim_xyz], x[..., self.dim_xyz :]
165165
for i in range(8):
166166
if i == 4:
167167
x = self.layers_xyz[i](torch.cat((xyz, x), -1))

nerf/nerf_helpers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,20 @@ def positional_encoding(
130130
encoding = [tensor] if include_input else []
131131
frequency_bands = None
132132
if log_sampling:
133-
frequency_bands = 2. ** torch.linspace(
134-
0., num_encoding_functions - 1, num_encoding_functions,
135-
dtype=tensor.dtype, device=tensor.device,
133+
frequency_bands = 2.0 ** torch.linspace(
134+
0.0,
135+
num_encoding_functions - 1,
136+
num_encoding_functions,
137+
dtype=tensor.dtype,
138+
device=tensor.device,
136139
)
137140
else:
138141
frequency_bands = torch.linspace(
139-
2. ** 0., 2. ** (num_encoding_functions - 1), num_encoding_functions,
140-
dtype=tensor.dtype, device=tensor.device
142+
2.0 ** 0.0,
143+
2.0 ** (num_encoding_functions - 1),
144+
num_encoding_functions,
145+
dtype=tensor.dtype,
146+
device=tensor.device,
141147
)
142148

143149
for freq in frequency_bands:

nerf/train_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,16 @@ def run_one_iter_of_nerf(
180180
for batch in batches
181181
]
182182
synthesized_images = list(zip(*pred))
183-
synthesized_images = [torch.cat(image, dim=0) if image[0] is not None else (None) for image in synthesized_images]
183+
synthesized_images = [
184+
torch.cat(image, dim=0) if image[0] is not None else (None)
185+
for image in synthesized_images
186+
]
184187
if mode == "validation":
185188
synthesized_images = [
186189
image.view(shape) if image is not None else None
187190
for (image, shape) in zip(synthesized_images, restore_shapes)
188191
]
189-
192+
190193
# Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine
191194
# (assuming both the coarse and fine networks are used).
192195
if model_fine:

tiny_nerf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import torch
77
from tqdm import tqdm, trange
88

9-
from nerf import (cumprod_exclusive, get_minibatches, get_ray_bundle,
10-
positional_encoding)
9+
from nerf import cumprod_exclusive, get_minibatches, get_ray_bundle, positional_encoding
1110

1211

1312
def compute_query_points_from_rays(

train_nerf.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from torch.utils.tensorboard import SummaryWriter
1111
from tqdm import tqdm, trange
1212

13-
from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
14-
load_llff_data, meshgrid_xy, models, mse2psnr,
15-
get_embedding_function, run_one_iter_of_nerf)
13+
from nerf import (CfgNode, get_embedding_function, get_ray_bundle, img2mse,
14+
load_blender_data, load_llff_data, meshgrid_xy, models,
15+
mse2psnr, run_one_iter_of_nerf)
1616

1717

1818
def main():
@@ -63,7 +63,7 @@ def main():
6363
H, W = int(H), int(W)
6464
hwf = [H, W, focal]
6565
if cfg.nerf.train.white_background:
66-
images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
66+
images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
6767
elif cfg.dataset.type.lower() == "llff":
6868
images, poses, bds, render_poses, i_test = load_llff_data(
6969
cfg.dataset.basedir, factor=cfg.dataset.downsample_factor
@@ -104,7 +104,7 @@ def main():
104104
include_input=cfg.models.coarse.include_input_xyz,
105105
log_sampling=cfg.models.coarse.log_sampling_xyz,
106106
)
107-
107+
108108
encode_direction_fn = None
109109
if cfg.models.coarse.use_viewdirs:
110110
encode_direction_fn = get_embedding_function(
@@ -250,7 +250,7 @@ def main():
250250
rgb_fine[..., :3], target_ray_values[..., :3]
251251
)
252252
# loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
253-
loss = 0.
253+
loss = 0.0
254254
# if fine_loss is not None:
255255
# loss = fine_loss
256256
# else:
@@ -337,7 +337,7 @@ def main():
337337
)
338338
target_ray_values = img_target
339339
coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
340-
loss, fine_loss = 0., 0.
340+
loss, fine_loss = 0.0, 0.0
341341
if rgb_fine is not None:
342342
fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
343343
loss = fine_loss

0 commit comments

Comments (0)