Skip to content

Commit 69b585c

Browse files
committed
feat: support pcc compression for rd benchmark
1 parent 329997e commit 69b585c

File tree

11 files changed

+1000
-21
lines changed

11 files changed

+1000
-21
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/bin/bash
2+
3+
#!/bin/bash
4+
5+
# Define the list of GPU IDs to use
6+
GPU_IDS=(0 1 2 3) # You can modify this list, e.g., GPU_IDS=(0 2 5 7)
7+
8+
# Function to run a single experiment
9+
run_experiment() {
10+
local gpu_id=$1
11+
local rp_id=$2
12+
13+
echo "Starting experiment rp${rp_id} on GPU ${gpu_id}"
14+
15+
CUDA_VISIBLE_DEVICES=${gpu_id} python compress_ply_sequence.py pcc_compression_rp${rp_id} \
16+
--data_factor 1 \
17+
--ply_dir /work/Users/lisicheng/Dataset/GSC_splats/m71763_bartender_stable/track \
18+
--data_dir /work/Users/lisicheng/Dataset/GSC_splats/m71763_bartender_stable/colmap_data \
19+
--result_dir results/mpeg150/pcc_anchor/rp${rp_id} \
20+
--frame_num 16 \
21+
--lpips_net vgg \
22+
--no-normalize_world_space \
23+
--scene_type GSC \
24+
--test_view_id 9 11 \
25+
26+
echo "Experiment rp${rp_id} started on GPU ${gpu_id}"
27+
}
28+
29+
# Check if the number of GPUs is sufficient
30+
if [ ${#GPU_IDS[@]} -lt 4 ]; then
31+
echo "Warning: Number of GPUs is less than the number of experiments, some experiments will be skipped"
32+
fi
33+
34+
# Launch experiments in parallel
35+
for i in {0..3}; do
36+
if [ $i -lt ${#GPU_IDS[@]} ]; then
37+
run_experiment ${GPU_IDS[$i]} $i &
38+
echo "Launched experiment rp${i} on GPU ${GPU_IDS[$i]} in background"
39+
else
40+
echo "Skipping experiment rp${i} due to insufficient GPUs"
41+
fi
42+
done
43+
44+
# Wait for all background processes to complete
45+
wait
46+
47+
echo "All experiments completed"

examples/compress_ply_sequence.py

Lines changed: 169 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
from gsplat.compression_simulation import CompressionSimulation
5151
from gsplat.compression_simulation.entropy_model import Entropy_factorized_optimized_refactor, Entropy_gaussian
5252

53+
from helper.ges_tm.pre_process_gaussian import load_ply_and_quant
54+
from helper.ges_tm.post_process_gaussian import inverse_load_ply
5355

5456

5557
@dataclass
@@ -89,6 +91,19 @@ def default_qp_values() -> Dict[str, Union[int, Dict[str, Any]]]:
8991

9092
@dataclass
9193
class CompressionConfig:
94+
rate_point: str = "rp0"
95+
96+
@dataclass
97+
class PCCompressionConfig(CompressionConfig):
98+
pcc_config_filename: str = "encoder_r05.cfg"
99+
100+
def to_dict(self) -> Dict[str, Any]:
101+
return {
102+
"pcc_config_filename": self.pcc_config_filename
103+
}
104+
105+
@dataclass
106+
class VideoCompressionConfig(CompressionConfig):
92107
# Use PLAS sort in compression or not
93108
use_sort: bool = True
94109
# Verbose or not
@@ -133,7 +148,7 @@ class Config:
133148
# qp: Optional[int] = None
134149
# Configuration for compression methods
135150
compression_cfg: CompressionConfig = field(
136-
default_factory=CompressionConfig
151+
default_factory=VideoCompressionConfig
137152
)
138153

139154
# Enable profiler
@@ -298,6 +313,8 @@ class Config:
298313
data_dir: str = ""
299314
# frame num
300315
frame_num: int = 1
316+
# anchor type
317+
anchor_type: Literal["video", "pcc"] = "video"
301318

302319
class Runner:
303320
def __init__(
@@ -347,17 +364,17 @@ def __init__(
347364
# load dataset
348365
self.trainset_list, self.valset_list = self.set_up_datasets(cfg.data_dir, cfg.frame_num, cfg)
349366

350-
compression_cfg = cfg.compression_cfg.to_dict()
367+
self.compression_cfg = cfg.compression_cfg.to_dict()
351368
if cfg.compression == "seq_hevc":
352-
self.compression_method = SeqHevcCompression(**compression_cfg)
369+
self.compression_method = SeqHevcCompression(**self.compression_cfg)
353370

354371
def load_ply_sequences(
355372
self, ply_dir: str, frame_num: int
356373
) -> List[torch.nn.ParameterDict]:
357-
self.filename_list = sorted(glob.glob(os.path.join(ply_dir, "*.ply")))
374+
self.ply_filename_list = sorted(glob.glob(os.path.join(ply_dir, "*.ply")))
358375

359376
splats_list = []
360-
for filename in tqdm.tqdm(self.filename_list[:frame_num], desc="Loading .ply file"):
377+
for filename in tqdm.tqdm(self.ply_filename_list[:frame_num], desc="Loading .ply file"):
361378
splats = load_ply(filename)
362379
splats_list.append(splats.to("cuda"))
363380

@@ -396,11 +413,11 @@ def set_up_datasets(
396413
return trainset_list, valset_list
397414

398415
def compress(self, ):
399-
"""Entry for running compression."""
400-
print("Running compression...")
416+
"""Entry for running video anchor compression."""
417+
print("Running video anchor compression...")
401418
world_rank = self.world_rank
402419

403-
compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}"
420+
compress_dir = f"{cfg.result_dir}/compression"
404421

405422
if os.path.exists(compress_dir):
406423
shutil.rmtree(compress_dir)
@@ -409,8 +426,6 @@ def compress(self, ):
409426
splats_videos = self.compression_method.reorganize(self.splats_list)
410427
self.compression_method.compress(compress_dir)
411428
video_splats_c = self.compression_method.decompress(compress_dir)
412-
413-
# splats_list_c = self.compression_method.deorganize(splats_videos)
414429
splats_list_c = self.compression_method.deorganize(video_splats_c)
415430

416431
for splats, splats_c in zip(self.splats_list, splats_list_c):
@@ -419,6 +434,93 @@ def compress(self, ):
419434

420435
self.eval(stage="compress")
421436

437+
def pcc_compress(self, ):
438+
"""Entry for running pc anchor compression."""
439+
print("Running pc anchor compression...")
440+
import subprocess
441+
442+
compress_dir = f"{cfg.result_dir}/compression"
443+
intermediate_dir = f"{cfg.result_dir}/intermediate"
444+
log_dir = f"{cfg.result_dir}/log"
445+
rec_dir = f"{cfg.result_dir}/rec"
446+
447+
if os.path.exists(compress_dir) or os.path.exists(intermediate_dir):
448+
shutil.rmtree(compress_dir)
449+
shutil.rmtree(intermediate_dir)
450+
shutil.rmtree(log_dir)
451+
# shutil.rmtree(rec_dir)
452+
os.makedirs(compress_dir, exist_ok=True)
453+
os.makedirs(intermediate_dir, exist_ok=True)
454+
os.makedirs(log_dir, exist_ok=True)
455+
os.makedirs(rec_dir, exist_ok=True)
456+
457+
for f_id, ply_file in enumerate(self.ply_filename_list[:self.frame_num]):
458+
# preprocess: fixed-point quantization
459+
temp_frame_dir = os.path.join(intermediate_dir, f"frame{f_id:03d}")
460+
os.makedirs(temp_frame_dir, exist_ok=True)
461+
load_ply_and_quant(ply_file, temp_frame_dir)
462+
463+
# encode
464+
print(f"Encode frame{f_id:03d} via GeS-TM.")
465+
quant_ply_file = temp_frame_dir + f"/quant_splats.ply"
466+
encoded_bin_file = compress_dir + f"/frame{f_id:03d}.bin"
467+
encode_cmd = [
468+
'./helper/ges_tm/tmc3',
469+
'-c',
470+
f"./helper/ges_tm/{self.compression_cfg['pcc_config_filename']}",
471+
f'--uncompressedDataPath={quant_ply_file}',
472+
f'--compressedStreamPath={encoded_bin_file}'
473+
]
474+
475+
encode_log_file = os.path.join(log_dir, f"frame{f_id:03d}_encode_log.txt")
476+
with open(encode_log_file, 'w') as log_file:
477+
result = subprocess.run(encode_cmd,
478+
capture_output=True,
479+
text=True, # output text rather than byte
480+
)
481+
log_file.write(result.stdout)
482+
log_file.write(result.stderr)
483+
484+
# decode
485+
print(f"Decode frame{f_id:03d} via GeS-TM.")
486+
decoded_ply_file = temp_frame_dir + f"/decoded_quant_splats.ply"
487+
decode_cmd = [
488+
'./helper/ges_tm/tmc3',
489+
'-c',
490+
'./helper/ges_tm/decoder.cfg',
491+
f'--compressedStreamPath={encoded_bin_file}',
492+
f'--reconstructedDataPath={decoded_ply_file}'
493+
]
494+
495+
decode_log_file = os.path.join(log_dir, f"frame{f_id:03d}_decode_log.txt")
496+
with open(decode_log_file, 'w') as log_file:
497+
498+
start_time = time.time()
499+
result = subprocess.run(decode_cmd,
500+
capture_output=True,
501+
text=True, # output text rather than byte
502+
)
503+
504+
end_time = time.time()
505+
elapsed_time = end_time - start_time
506+
507+
log_file.write(result.stdout)
508+
log_file.write(f"\nExecution time of decoding: {elapsed_time:.3f} seconds\n")
509+
510+
print(f"Execution time of decoding: {elapsed_time:.3f} seconds")
511+
512+
# postprocess
513+
output_filename = os.path.join(rec_dir, f"frame{f_id:03d}.ply")
514+
inverse_load_ply(decoded_ply_file, output_filename)
515+
516+
splats_list_c = self.load_ply_sequences(rec_dir, self.frame_num)
517+
518+
for splats, splats_c in zip(self.splats_list, splats_list_c):
519+
for k in splats.keys():
520+
splats[k].data = splats_c[k].to(self.device)
521+
522+
self.eval(stage="compress")
523+
422524
def rasterize_splats(
423525
self,
424526
camtoworlds: Tensor,
@@ -583,7 +685,7 @@ def format_size(size_bytes):
583685
return f"{size_bytes/(1024**3):.2f} GB"
584686

585687
# rate summary
586-
directory_path = os.path.join(self.cfg.result_dir, "compression", "rank0")
688+
directory_path = os.path.join(self.cfg.result_dir, "compression")
587689

588690
# Check if directory exists
589691
if not os.path.exists(directory_path):
@@ -646,29 +748,75 @@ def stack_render_img_to_vid(self):
646748

647749
def main(local_rank: int, world_rank, world_size: int, cfg: Config):
648750
runner = Runner(local_rank, world_rank, world_size, cfg)
649-
751+
650752
runner.eval()
651-
runner.compress()
753+
if cfg.anchor_type == "video":
754+
runner.compress()
755+
elif cfg.anchor_type == "pcc":
756+
runner.pcc_compress()
757+
else:
758+
raise NotImplementedError(f"{cfg.anchor_type} Anchor has not been implemented.")
759+
652760
runner.summary()
653761
runner.stack_render_img_to_vid()
654-
655762

656763
if __name__ == "__main__":
657764
configs = {
765+
"pcc_compression_debug":(
766+
"Use PCCompression.",
767+
Config(
768+
anchor_type="pcc",
769+
compression_cfg=PCCompressionConfig()
770+
)
771+
),
772+
"pcc_compression_rp0":(
773+
"Use PCCompression.",
774+
Config(
775+
anchor_type="pcc",
776+
compression_cfg=PCCompressionConfig(
777+
pcc_config_filename="encoder_r05.cfg"
778+
)
779+
)
780+
),
781+
"pcc_compression_rp1":(
782+
"Use PCCompression.",
783+
Config(
784+
anchor_type="pcc",
785+
compression_cfg=PCCompressionConfig(
786+
pcc_config_filename="encoder_r06.cfg"
787+
)
788+
)
789+
),
790+
"pcc_compression_rp2":(
791+
"Use PCCompression.",
792+
Config(
793+
anchor_type="pcc",
794+
compression_cfg=PCCompressionConfig(
795+
pcc_config_filename="encoder_r07.cfg"
796+
)
797+
)
798+
),
799+
"pcc_compression_rp3":(
800+
"Use PCCompression.",
801+
Config(
802+
anchor_type="pcc",
803+
compression_cfg=PCCompressionConfig(
804+
pcc_config_filename="encoder_r08.cfg"
805+
)
806+
)
807+
),
658808
"x265_compression_debug":(
659809
"Use HevcCompression.",
660810
Config(
661811
compression="seq_hevc",
662-
compression_cfg=CompressionConfig(
663-
# qp=4,
664-
n_clusters=8192)
812+
compression_cfg=VideoCompressionConfig(n_clusters=8192)
665813
)
666814
),
667815
"x265_compression_rp0": (
668816
"Use HevcCompression.",
669817
Config(
670818
compression="seq_hevc",
671-
compression_cfg=CompressionConfig(
819+
compression_cfg=VideoCompressionConfig(
672820
qp={
673821
"means": -1,
674822
"opacities": 4,
@@ -688,7 +836,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
688836
"Use HevcCompression.",
689837
Config(
690838
compression="seq_hevc",
691-
compression_cfg=CompressionConfig(
839+
compression_cfg=VideoCompressionConfig(
692840
qp={
693841
"means": -1,
694842
"opacities": 4,
@@ -708,7 +856,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
708856
"Use HevcCompression.",
709857
Config(
710858
compression="seq_hevc",
711-
compression_cfg=CompressionConfig(
859+
compression_cfg=VideoCompressionConfig(
712860
qp={
713861
"means": -1,
714862
"opacities": 10,
@@ -728,7 +876,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
728876
"Use HevcCompression.",
729877
Config(
730878
compression="seq_hevc",
731-
compression_cfg=CompressionConfig(
879+
compression_cfg=VideoCompressionConfig(
732880
qp={
733881
"means": -1,
734882
"opacities": 16,

examples/helper/ges_tm/decoder.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
mode: 1
2+
outputBinaryPly: 1

0 commit comments

Comments
 (0)