@@ -44,17 +44,12 @@ def __init__(self):
4444 torch .profiler .ProfilerActivity .CUDA ,
4545 ]
4646
47- # 基础配置
48- self .wait = 4900 # 开始记录前等待的步数
49- self .warmup = 50 # 预热步数
50- self .active = 30_000 # 实际分析的步数
51- # self.repeat = 2 # 重复次数
52- # self.skip_first = 10 # 跳过前N步(可选)
47+ self .wait = 4900
48+ self .warmup = 50
49+ self .active = 30_000
5350
54- # 创建schedule
5551 self .schedule = self ._create_schedule ()
5652
57- # 其他profiler设置
5853 self .on_trace_ready = torch .profiler .tensorboard_trace_handler ('./log/profiler' )
5954 self .record_shapes = True
6055 self .profile_memory = True
@@ -66,12 +61,9 @@ def _create_schedule(self):
6661 wait = self .wait ,
6762 warmup = self .warmup ,
6863 active = self .active ,
69- # repeat=self.repeat,
70- # skip_first=self.skip_first
7164 )
7265
7366 def update_schedule (self , ** kwargs ):
74- """动态更新schedule参数"""
7567 for key , value in kwargs .items ():
7668 if hasattr (self , key ):
7769 setattr (self , key , value )
@@ -329,19 +321,15 @@ def __init__(self, cfg: Config) -> None:
329321 batch_size = 1 ,
330322 shuffle = True ,
331323 num_workers = 16 ,
332- # persistent_workers=True,
333324 pin_memory = True ,
334- # collate_fn=collate_fn
335325 )
336326
337327 self .testloader = torch .utils .data .DataLoader (
338328 self .testset ,
339329 batch_size = 1 ,
340330 shuffle = False ,
341331 num_workers = 8 ,
342- # persistent_workers=True,
343332 pin_memory = True ,
344- # collate_fn=collate_fn
345333 )
346334 # scene scale
347335 self .scene_scale = self .parser .scene_scale * 1.1 * cfg .global_scale
@@ -405,7 +393,6 @@ def __init__(self, cfg: Config) -> None:
405393 self .strategy_state = self .strategy .initialize_state (scene_scale = self .scene_scale )
406394
407395 # Compression Strategy
408- # TODO Compression Strategy should proceed here, according to GSplat
409396 self .compression_method = None
410397 if cfg .compression is not None :
411398 if cfg .compression == "png" :
@@ -416,8 +403,6 @@ def __init__(self, cfg: Config) -> None:
416403 raise ValueError (f"Unknown compression strategy: { cfg .compression } " )
417404
418405 if cfg .compression_sim :
419- # TODO: bad impl.
420- # cap_max = cfg.strategy.cap_max if cfg.strategy.cap_max is not None else None
421406 self .compression_sim_method = STGCompressionSimulation (cfg .quantization_sim_type ,
422407 cfg .entropy_model_opt ,
423408 cfg .entropy_steps ,
@@ -789,8 +774,6 @@ def train(self):
789774 optimizer .zero_grad (set_to_none = True )
790775
791776 self .step_profiler ()
792-
793- # torch.cuda.empty_cache()
794777
795778 # eval the full set
796779 if step in [i - 1 for i in cfg .eval_steps ]:
0 commit comments