Commit 8fd0512

[fix] formatting
1 parent 8258ba2 commit 8fd0512

34 files changed: +202 / -203 lines
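Most of the hunks below apply the same mechanical reformatting: multi-line assert statements are rewritten so the condition stays on the assert line and only the failure message is wrapped in parentheses (the remaining hunks reorder imports and clean up string literals). A minimal Python sketch of the assert pattern, using illustrative variable names rather than code taken from the repository:

images_found = 8
expected_frames = 8

# Old style: the condition is parenthesized and split across lines,
# so the failure message trails the closing parenthesis.
assert (
    images_found == expected_frames
), f"found {images_found} images, {expected_frames} were expected"

# New style used in this commit: the condition stays on the assert line
# and only the message is parenthesized for line wrapping.
assert images_found == expected_frames, (
    f"found {images_found} images, {expected_frames} were expected"
)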

compressai_vision/codecs/encdec_utils/png_yuv.py

Lines changed: 12 additions & 12 deletions
@@ -89,9 +89,9 @@ def __call__(self, input: Dict, file_prefix: str):
         images_in_folder = len(list(parent.glob(ext)))
         nb_frames = input["last_frame"] - input["frame_skip"]

-        assert (
-            images_in_folder == nb_frames
-        ), f"input folder contains {images_in_folder} images, {nb_frames} were expected"
+        assert images_in_folder == nb_frames, (
+            f"input folder contains {images_in_folder} images, {nb_frames} were expected"
+        )

         input_info = [
             "-pattern_type",
@@ -122,9 +122,9 @@ def __call__(self, input: Dict, file_prefix: str):

         # Use existing YUV (if found and indicated for use):
         if self.use_yuv:
-            assert (
-                yuv_file is not None
-            ), "Parameter 'use_yuv' set True but YUV file not found."
+            assert yuv_file is not None, (
+                "Parameter 'use_yuv' set True but YUV file not found."
+            )
             size = yuv_file.stat().st_size
             bytes_per_luma_sample = {"yuv420p": 1.5}[chroma_format]
             bytes_per_sample = (input_bitdepth + 7) >> 3
@@ -135,9 +135,9 @@ def __call__(self, input: Dict, file_prefix: str):
                 * bytes_per_sample
                 * nb_frames
             )
-            assert (
-                size == expected_size
-            ), f"YUV found for input but expected size of {expected_size} bytes differs from actual size of {size} bytes"
+            assert size == expected_size, (
+                f"YUV found for input but expected size of {expected_size} bytes differs from actual size of {size} bytes"
+            )
             shutil.copy(yuv_file, yuv_in_path)
             print(f"Using pre-existing YUV file: {yuv_file}")
             return (yuv_in_path, nb_frames, frame_width, frame_height, file_prefix)
@@ -204,9 +204,9 @@ def __call__(
         frame_width = video_info["width"]
         frame_height = video_info["height"]

-        assert (
-            "420" in video_info["format"].value
-        ), f"Only support yuv420, but got {video_info['format']}"
+        assert "420" in video_info["format"].value, (
+            f"Only support yuv420, but got {video_info['format']}"
+        )
         pix_fmt_suffix = "10le" if video_info["bitdepth"] == 10 else ""
         chroma_format = "yuv420p"

compressai_vision/codecs/sic_sfu2022.py

Lines changed: 9 additions & 9 deletions
@@ -151,13 +151,13 @@ def __init__(self, device: str, **kwargs):
         # root_url = "https://dspub.blob.core.windows.net/compressai/sic_sfu2022"

         self.target_tlayer = int(kwargs["target_task_layer"])
-        assert (
-            self.num_tasks == 2 or self.num_tasks == 3
-        ), f"SIC_SFU2023 supports only 2 or 3 task layers, but got {self.num_tasks}"
-        assert (
-            self.target_tlayer < self.num_tasks
-        ), f"target task layer must be lower than the number of tasks, \
+        assert self.num_tasks == 2 or self.num_tasks == 3, (
+            f"SIC_SFU2023 supports only 2 or 3 task layers, but got {self.num_tasks}"
+        )
+        assert self.target_tlayer < self.num_tasks, (
+            f"target task layer must be lower than the number of tasks, \
             but got {self.target_tlayer} < {self.num_tasks}"
+        )

         self.trg_vmodel = self.vmodels[self.target_tlayer]

@@ -572,9 +572,9 @@ def compress(self, x, target_layer=0, feature_only=False):
                 "models (the entropy coder is run sequentially on CPU)."
             )

-        assert (
-            target_layer < self.NUM_LAYERS
-        ), f"Got the target layer {target_layer}, but should be less than {self.NUM_LAYERS}"
+        assert target_layer < self.NUM_LAYERS, (
+            f"Got the target layer {target_layer}, but should be less than {self.NUM_LAYERS}"
+        )

         y = self.g_a(x)
         z = self.h_a(y)

compressai_vision/codecs/std_codecs.py

Lines changed: 11 additions & 13 deletions
@@ -436,7 +436,7 @@ def encode(
         nb_frames, frame_height, frame_width = frames.size()
         input_bitdepth = self.enc_cfgs["input_bitdepth"]
         chroma_format = self.enc_cfgs["chroma_format"]
-        file_prefix = f"{file_prefix}_{frame_width}x{frame_height}_{self.frame_rate }fps_{input_bitdepth}bit_p{chroma_format}"
+        file_prefix = f"{file_prefix}_{frame_width}x{frame_height}_{self.frame_rate}fps_{input_bitdepth}bit_p{chroma_format}"
         yuv_in_path = f"{file_prefix}_input.yuv"

         self.yuvio.setWriter(
@@ -480,9 +480,9 @@ def encode(
             for partial in list_of_bitstreams:
                 Path(partial).unlink()

-        assert Path(
-            bitstream_path
-        ).is_file(), f"bitstream {bitstream_path} was not created"
+        assert Path(bitstream_path).is_file(), (
+            f"bitstream {bitstream_path} was not created"
+        )

         if not remote_inference:
             inner_codec_bitstream = load_bitstream(bitstream_path)
@@ -601,12 +601,12 @@ def decode(
         for file_path in sorted(Path(dec_path).glob(f"*{file_prefix}*.png")):
             rec_frames.append(str(file_path))

-        assert (
-            file_prefix in rec_frames[0]
-        ), f"Can't find a correct filename with {file_prefix} in {dec_path}"
-        assert (
-            len(rec_frames) == 1
-        ), f"Number of retrieved file must be 1, but got {len(rec_frames)}"
+        assert file_prefix in rec_frames[0], (
+            f"Can't find a correct filename with {file_prefix} in {dec_path}"
+        )
+        assert len(rec_frames) == 1, (
+            f"Number of retrieved file must be 1, but got {len(rec_frames)}"
+        )

         conversion_time = 0
         output = {"file_names": rec_frames}
@@ -1090,9 +1090,7 @@ def add_descriptor_modes(descriptors, descriptor_mode):
         "vcm_ctc": ["load", "UsingDescriptor", "load", "load"],
         "load": ["load", "UsingDescriptor", "load", "load"],
         "generate": ["save", "GeneratingDescriptor", "save", "save"],
-    }[
-        descriptor_mode
-    ]
+    }[descriptor_mode]
     items = str(bitstream_path).split("/")
     dataset = {
         "SFUHW": "SFU",

compressai_vision/codecs/utils.py

Lines changed: 3 additions & 3 deletions
@@ -257,9 +257,9 @@ def reshape_frame_to_feature_pyramid(
     """reshape a frame of channels into the feature pyramid"""

     assert isinstance(x, (Tensor, Dict))
-    assert (
-        packing_all_in_one is True
-    ), "packing_all_in_one = False is not supported yet"
+    assert packing_all_in_one is True, (
+        "packing_all_in_one = False is not supported yet"
+    )

     top_y = 0
     tiled_frames = {}

compressai_vision/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,9 +31,9 @@
     DataCatalog,
     DefaultDataset,
     Detectron2Dataset,
+    SamDataset,
     TrackingDataset,
     deccode_compressed_rle,
-    SamDataset,
 )
 from .utils import get_seq_info

compressai_vision/datasets/image.py

Lines changed: 4 additions & 4 deletions
@@ -54,8 +54,8 @@
     JDECustomMapper,
     LinearMapper,
     MMPOSECustomMapper,
-    YOLOXCustomMapper,
     SAMCustomMapper,
+    YOLOXCustomMapper,
 )


@@ -246,9 +246,9 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
         if kwargs["linear_mapper"] is True:
             mapper = LinearMapper()
         else:
-            assert (
-                kwargs["cfg"] is not None
-            ), "A proper mapper information via cfg must be provided"
+            assert kwargs["cfg"] is not None, (
+                "A proper mapper information via cfg must be provided"
+            )
             mapper = DatasetMapper(kwargs["cfg"], False)

         self.mapDataset = MapDataset(_dataset, mapper)

compressai_vision/datasets/utils.py

Lines changed: 8 additions & 8 deletions
@@ -37,8 +37,8 @@

 from jde.utils.datasets import letterbox
 from mmpose.structures.bbox import get_warp_matrix
-from torchvision import transforms
 from torch.nn import functional as F
+from torchvision import transforms

 __all__ = [
     "MMPOSECustomMapper",
@@ -139,9 +139,9 @@ def __call__(self, dataset_dict):
         # Read image
         org_img = cv2.imread(dataset_dict["file_name"])  # return img in BGR by default

-        assert (
-            len(org_img.shape) == 3
-        ), f"detect an input image with 2 chs, {dataset_dict['file_name']}"
+        assert len(org_img.shape) == 3, (
+            f"detect an input image with 2 chs, {dataset_dict['file_name']}"
+        )

         img_h, img_w, _ = org_img.shape

@@ -220,9 +220,9 @@ def __call__(self, dataset_dict):
         # Read image
         org_img = cv2.imread(dataset_dict["file_name"])  # return img in BGR by default

-        assert (
-            len(org_img.shape) == 3
-        ), f"detect an input image with 2 chs, {dataset_dict['file_name']}"
+        assert len(org_img.shape) == 3, (
+            f"detect an input image with 2 chs, {dataset_dict['file_name']}"
+        )

         dataset_dict["height"], dataset_dict["width"], _ = org_img.shape

@@ -412,5 +412,5 @@ def get_seq_info(seq_info_path):
     config.read(seq_info_path)
     fps = config["Sequence"]["frameRate"]
     total_frame = config["Sequence"]["seqLength"]
-    name = f'{config["Sequence"]["name"]}_{config["Sequence"]["imWidth"]}x{config["Sequence"]["imHeight"]}_{fps}'
+    name = f"{config['Sequence']['name']}_{config['Sequence']['imWidth']}x{config['Sequence']['imHeight']}_{fps}"
     return name, int(fps), int(total_frame)

compressai_vision/evaluators/evaluators.py

Lines changed: 9 additions & 9 deletions
@@ -725,9 +725,9 @@ def digest_summary(summary):
         return ret

     def mot_eval(self):
-        assert len(self.dataset) == len(
-            self._predictions
-        ), "Total number of frames are mismatch"
+        assert len(self.dataset) == len(self._predictions), (
+            "Total number of frames are mismatch"
+        )

         # skip the very first frame
         for gt_frame in self.dataset[1:]:
@@ -818,9 +818,9 @@ def __init__(
         assert self.seqinfo_path is not None, "Sequence Information must be provided"

     def mot_eval(self):
-        assert len(self.dataset) == len(
-            self._predictions
-        ), "Total number of frames are mismatch"
+        assert len(self.dataset) == len(self._predictions), (
+            "Total number of frames are mismatch"
+        )

         self._save_all_eval_info(self._predictions)
         _pd_pd = self._format_pd_in_motchallenge(self._predictions)
@@ -874,9 +874,9 @@ def __init__(
         )

     def mot_eval(self):
-        assert len(self.dataset) == len(
-            self._predictions
-        ), "Total number of frames are mismatch"
+        assert len(self.dataset) == len(self._predictions), (
+            "Total number of frames are mismatch"
+        )

         self._save_all_eval_info(self._predictions)
         _pd_pd = self._format_pd_in_motchallenge(self._predictions)

compressai_vision/evaluators/tf_evaluation_utils/np_box_list_ops.py

Lines changed: 2 additions & 2 deletions
@@ -310,7 +310,7 @@ def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh, max_outpu
     elif len(scores.shape) == 2:
         if scores.shape[1] is None:
             raise ValueError(
-                "scores field must have statically defined second " "dimension"
+                "scores field must have statically defined second dimension"
             )
     else:
         raise ValueError("scores field must be of rank 1 or 2")
@@ -551,7 +551,7 @@ def filter_scores_greater_than(boxlist, thresh):
         raise ValueError("Scores should have rank 1 or 2")
     if len(scores.shape) == 2 and scores.shape[1] != 1:
         raise ValueError(
-            "Scores should have rank 1 or have shape " "consistent with [None, 1]"
+            "Scores should have rank 1 or have shape consistent with [None, 1]"
         )
     high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(
         np.int32

compressai_vision/evaluators/tf_evaluation_utils/np_box_mask_list_ops.py

Lines changed: 2 additions & 2 deletions
@@ -312,7 +312,7 @@ def multi_class_non_max_suppression(
     elif len(scores.shape) == 2:
         if scores.shape[1] is None:
             raise ValueError(
-                "scores field must have statically defined second " "dimension"
+                "scores field must have statically defined second dimension"
            )
     else:
         raise ValueError("scores field must be of rank 1 or 2")
@@ -430,7 +430,7 @@ def filter_scores_greater_than(box_mask_list, thresh):
         raise ValueError("Scores should have rank 1 or 2")
     if len(scores.shape) == 2 and scores.shape[1] != 1:
         raise ValueError(
-            "Scores should have rank 1 or have shape " "consistent with [None, 1]"
+            "Scores should have rank 1 or have shape consistent with [None, 1]"
         )
     high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), [-1]).astype(
         np.int32
