Skip to content

Commit ff1a564

Browse files
committed
misc
1 parent 072ea3d commit ff1a564

File tree

1 file changed

+3
-20
lines changed

1 file changed

+3
-20
lines changed

examples/simple_trainer_dyngs.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,12 @@ def __init__(self):
4444
torch.profiler.ProfilerActivity.CUDA,
4545
]
4646

47-
# 基础配置
48-
self.wait = 4900 # 开始记录前等待的步数
49-
self.warmup = 50 # 预热步数
50-
self.active = 30_000 # 实际分析的步数
51-
# self.repeat = 2 # 重复次数
52-
# self.skip_first = 10 # 跳过前N步(可选)
47+
self.wait = 4900
48+
self.warmup = 50
49+
self.active = 30_000
5350

54-
# 创建schedule
5551
self.schedule = self._create_schedule()
5652

57-
# 其他profiler设置
5853
self.on_trace_ready = torch.profiler.tensorboard_trace_handler('./log/profiler')
5954
self.record_shapes = True
6055
self.profile_memory = True
@@ -66,12 +61,9 @@ def _create_schedule(self):
6661
wait=self.wait,
6762
warmup=self.warmup,
6863
active=self.active,
69-
# repeat=self.repeat,
70-
# skip_first=self.skip_first
7164
)
7265

7366
def update_schedule(self, **kwargs):
74-
"""动态更新schedule参数"""
7567
for key, value in kwargs.items():
7668
if hasattr(self, key):
7769
setattr(self, key, value)
@@ -329,19 +321,15 @@ def __init__(self, cfg: Config) -> None:
329321
batch_size=1,
330322
shuffle=True,
331323
num_workers=16,
332-
# persistent_workers=True,
333324
pin_memory=True,
334-
# collate_fn=collate_fn
335325
)
336326

337327
self.testloader = torch.utils.data.DataLoader(
338328
self.testset,
339329
batch_size=1,
340330
shuffle=False,
341331
num_workers=8,
342-
# persistent_workers=True,
343332
pin_memory=True,
344-
# collate_fn=collate_fn
345333
)
346334
# scene scale
347335
self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
@@ -405,7 +393,6 @@ def __init__(self, cfg: Config) -> None:
405393
self.strategy_state = self.strategy.initialize_state(scene_scale=self.scene_scale)
406394

407395
# Compression Strategy
408-
# TODO Compression Strategy should proceed here, according to GSplat
409396
self.compression_method = None
410397
if cfg.compression is not None:
411398
if cfg.compression == "png":
@@ -416,8 +403,6 @@ def __init__(self, cfg: Config) -> None:
416403
raise ValueError(f"Unknown compression strategy: {cfg.compression}")
417404

418405
if cfg.compression_sim:
419-
# TODO: bad impl.
420-
# cap_max = cfg.strategy.cap_max if cfg.strategy.cap_max is not None else None
421406
self.compression_sim_method = STGCompressionSimulation(cfg.quantization_sim_type,
422407
cfg.entropy_model_opt,
423408
cfg.entropy_steps,
@@ -789,8 +774,6 @@ def train(self):
789774
optimizer.zero_grad(set_to_none=True)
790775

791776
self.step_profiler()
792-
793-
# torch.cuda.empty_cache()
794777

795778
# eval the full set
796779
if step in [i - 1 for i in cfg.eval_steps]:

0 commit comments

Comments
 (0)