Skip to content

Commit 97a33e4

Browse files
authored
measure policy weight update (#286)
1 parent 1f27b54 commit 97a33e4

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/forge/actors/policy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,8 @@ async def update(self, version: int):
654654
logger.debug(f"{matching_keys=}")
655655
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
656656
loaded_weights = set()
657-
start = time.perf_counter()
657+
t = Tracer("policy_worker_perf/update", timer="gpu")
658+
t.start()
658659
# Entire state dict is stored in a single DCP handle
659660
if dcp_whole_state_dict_key in matching_keys:
660661
logger.info(
@@ -677,9 +678,7 @@ async def update(self, version: int):
677678
loaded = model.load_weights([(name, param)])
678679
del param
679680
loaded_weights.update(loaded)
680-
logger.info(
681-
f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds"
682-
)
681+
t.stop()
683682
logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")
684683

685684
@endpoint

0 commit comments

Comments
 (0)